Commit e2364931 authored by mashun1's avatar mashun1
Browse files

pixart-alpha

parents
Pipeline #861 canceled with stages
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data'
image_list_json = ['data_info.json',]
data = dict(type='InternalDataHed', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
image_size = 1024
# model setting
model = 'PixArtMS_XL_2'
fp32_attention = False # Set to True if you got NaN loss
load_from = 'path-to-pixart-checkpoints'
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
window_block_indexes = []
window_size=0
use_rel_pos=False
lewei_scale = 2.0
# training setting
num_workers=10
train_batch_size = 4 # set the batch size according to your VRAM
num_epochs = 10 # 3
gradient_accumulation_steps = 4
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=0)
save_model_epochs=5
save_model_steps=1000
log_interval = 20
eval_sampling_steps = 200
work_dir = 'output_debug/debug'
# controlnet related params
copy_blocks_num = 13
class_dropout_prob = 0.5
train_ratio = 1
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data/dreambooth/dataset'
data = dict(type='DreamBooth', root='dog6', prompt=['a photo of sks dog'], transform='default_train', load_vae_feat=True)
image_size = 1024
# model setting
model = 'PixArtMS_XL_2' # model for multi-scale training
fp32_attention = True
load_from = 'Path/to/PixArt-XL-2-1024-MS.pth'
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
window_block_indexes = []
window_size=0
use_rel_pos=False
aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
multi_scale = True # if use multiscale dataset model training
lewei_scale = 2.0
# training setting
num_workers=1
train_batch_size = 1
num_epochs = 200
gradient_accumulation_steps = 1
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=5e-6, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=0)
auto_lr = None
log_interval = 1
save_model_epochs=10000
save_model_steps=100
work_dir = 'output/debug'
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data'
image_list_json = ['data_info.json',]
data = dict(type='InternalDataHed', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
image_size = 512
# model setting
model = 'PixArt_XL_2'
fp32_attention = False # Set to True if you got NaN loss
load_from = 'path-to-pixart-checkpoints'
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
window_block_indexes = []
window_size=0
use_rel_pos=False
lewei_scale = 1.0
# training setting
num_workers=10
train_batch_size = 12 # 32 # max 96 for DiT-L/4 when grad_checkpoint
num_epochs = 1000 # 3
gradient_accumulation_steps = 4
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=0)
save_model_epochs=5
save_model_steps=1000
log_interval = 20
eval_sampling_steps = 200
work_dir = 'output_debug/debug'
# controlnet related params
copy_blocks_num = 13
class_dropout_prob = 0.5
train_ratio = 0.1
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data'
image_list_json = ['data_info.json',]
data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
image_size = 1024
# model setting
window_block_indexes = []
window_size=0
use_rel_pos=False
model = 'PixArt_XL_2'
fp32_attention = True
load_from = None
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
lewei_scale = 2.0
# training setting
num_workers=10
train_batch_size = 2 # 32
num_epochs = 200 # 3
gradient_accumulation_steps = 1
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=1000)
eval_sampling_steps = 200
log_interval = 20
save_model_epochs=1
save_model_steps=2000
work_dir = 'output/debug'
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data'
image_list_json = ['data_info.json',]
data = dict(type='InternalDataMS', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
image_size = 1024
# model setting
model = 'PixArtMS_XL_2' # model for multi-scale training
fp32_attention = True
load_from = None
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
window_block_indexes = []
window_size=0
use_rel_pos=False
aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
multi_scale = True # if use multiscale dataset model training
lewei_scale = 2.0
# training setting
num_workers=10
train_batch_size = 12 # max 14 for PixArt-xL/2 when grad_checkpoint
num_epochs = 10 # 3
gradient_accumulation_steps = 1
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=1000)
save_model_epochs=1
save_model_steps=2000
log_interval = 20
eval_sampling_steps = 200
work_dir = 'output/debug'
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data'
image_list_json = ['data_info.json',]
data = dict(type='InternalDataMS', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
image_size = 1024
# model setting
model = 'PixArtMS_XL_2' # model for multi-scale training
fp32_attention = False # Set to True if you got NaN loss
load_from = None
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
window_block_indexes = []
window_size=0
use_rel_pos=False
aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
multi_scale = True # if use multiscale dataset model training
lewei_scale = 2.0
# training setting
num_workers=4
train_batch_size = 16 # max 12 for PixArt-xL/2 when grad_checkpoint 16 for LCM-LoRA
num_epochs = 10 # 3
gradient_accumulation_steps = 1
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=0.0, eps=1e-10)
# optimizer = dict(type='CAMEWrapper', lr=1e-7, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
lr_schedule_args = dict(num_warmup_steps=100)
save_model_epochs=1
save_model_steps=200
valid_num=0 # take as valid aspect-ratio when sample number >= valid_num
log_interval = 10
eval_sampling_steps = 200
work_dir = 'output/debug'
# LCM
loss_type = 'huber'
huber_c = 0.001
num_ddim_timesteps=50
w_max = 15.0
w_min = 3.0
ema_decay = 0.95
cfg_scale = 4.5
class_dropout_prob = 0.
lora_rank = 32
\ No newline at end of file
_base_ = ['../PixArt_xl2_sam.py']
data_root = 'data'
# image_list_txt = ['part0.txt', 'part1.txt', 'part2.txt', 'part3.txt', 'part4.txt', 'part5.txt', 'part6.txt', 'part7.txt', 'part8.txt',
# 'part9.txt', 'part10.txt', 'part11.txt', 'part12.txt', 'part13.txt', 'part14.txt','part15.txt','part16.txt',
# 'part17.txt','part18.txt','part19.txt','part20.txt','part21.txt', 'part22.txt', 'part23.txt', 'part24.txt',
# 'part25.txt', 'part26.txt', 'part27.txt', 'part28.txt', 'part29.txt', 'part30.txt', 'part31.txt']
# data = dict(type='SAM', root='SA1B', image_list_txt=image_list_txt, transform='default_train', load_vae_feat=True)
image_list_txt = ['part0.txt']
data = dict(type='SAM', root='data_toy', image_list_txt=image_list_txt, transform='default_train', load_vae_feat=True)
image_size = 256
# model setting
window_block_indexes=[]
window_size=0
use_rel_pos=False
model = 'PixArt_XL_2'
fp32_attention = False
load_from = None
# vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
vae_pretrained = "pretrained_models/hub/pixart_alpha/sd-vae-ft-ema"
# training setting
use_fsdp=False # if use FSDP mode
num_workers=1
train_batch_size = 2 # 32
num_epochs = 2 # 3
gradient_accumulation_steps = 1
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=1000)
eval_sampling_steps = 1
log_interval = 1
save_model_epochs=1
save_model_steps=1
work_dir = 'output/debug'
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data'
image_list_json = ['data_info.json',]
data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
image_size = 256
# model setting
window_block_indexes=[]
window_size=0
use_rel_pos=False
model = 'PixArt_XL_2'
fp32_attention = True
load_from = None
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
# training setting
eval_sampling_steps = 200
num_workers=10
train_batch_size = 176 # 32 # max 96 for PixArt-L/4 when grad_checkpoint
num_epochs = 200 # 3
gradient_accumulation_steps = 1
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=1000)
log_interval = 20
save_model_epochs=5
work_dir = 'output/debug'
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data'
image_list_json = ['data_info.json',]
data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
image_size = 512
# model setting
window_block_indexes = []
window_size=0
use_rel_pos=False
model = 'PixArt_XL_2'
fp32_attention = True
load_from = None
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
lewei_scale = 1.0
# training setting
use_fsdp=False # if use FSDP mode
num_workers=10
train_batch_size = 38 # 32
num_epochs = 200 # 3
gradient_accumulation_steps = 1
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=1000)
eval_sampling_steps = 200
log_interval = 20
save_model_epochs=1
work_dir = 'output/debug'
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data'
image_list_json = ['data_info.json',]
data = dict(type='InternalDataMS', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
image_size = 512
# model setting
model = 'PixArtMS_XL_2' # model for multi-scale training
fp32_attention = True
load_from = None
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
window_block_indexes = []
window_size=0
use_rel_pos=False
aspect_ratio_type = 'ASPECT_RATIO_512' # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
multi_scale = True # if use multiscale dataset model training
lewei_scale = 1.0
# training setting
num_workers=10
train_batch_size = 40 # max 40 for PixArt-xL/2 when grad_checkpoint
num_epochs = 20 # 3
gradient_accumulation_steps = 1
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=1000)
save_model_epochs=1
save_model_steps=2000
log_interval = 20
eval_sampling_steps = 200
work_dir = 'output/debug'
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
from .iddpm import IDDPM
from .dpm_solver import DPMS
from .sa_sampler import SASolverSampler
import torch
from .model import gaussian_diffusion as gd
from .model.dpm_solver import model_wrapper, DPM_Solver, NoiseScheduleVP
def DPMS(model, condition, uncondition, cfg_scale, model_type='noise', noise_schedule="linear", guidance_type='classifier-free', model_kwargs=None, diffusion_steps=1000):
if model_kwargs is None:
model_kwargs = {}
betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
## 1. Define the noise schedule.
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)
## 2. Convert your discrete-time `model` to the continuous-time
## noise prediction model. Here is an example for a diffusion model
## `model` with the noise prediction type ("noise") .
model_fn = model_wrapper(
model,
noise_schedule,
model_type=model_type,
model_kwargs=model_kwargs,
guidance_type=guidance_type,
condition=condition,
unconditional_condition=uncondition,
guidance_scale=cfg_scale,
)
## 3. Define dpm-solver and sample by multistep DPM-Solver.
return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
\ No newline at end of file
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
from diffusion.model.respace import SpacedDiffusion, space_timesteps
from .model import gaussian_diffusion as gd
def IDDPM(
timestep_respacing,
noise_schedule="linear",
use_kl=False,
sigma_small=False,
predict_xstart=False,
learn_sigma=True,
pred_sigma=True,
rescale_learned_sigmas=False,
diffusion_steps=1000,
snr=False,
return_startx=False,
):
betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
if use_kl:
loss_type = gd.LossType.RESCALED_KL
elif rescale_learned_sigmas:
loss_type = gd.LossType.RESCALED_MSE
else:
loss_type = gd.LossType.MSE
if timestep_respacing is None or timestep_respacing == "":
timestep_respacing = [diffusion_steps]
return SpacedDiffusion(
use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
betas=betas,
model_mean_type=(
gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON
),
model_var_type=(
(gd.ModelVarType.LEARNED_RANGE if learn_sigma else (
gd.ModelVarType.FIXED_LARGE
if not sigma_small
else gd.ModelVarType.FIXED_SMALL
)
)
if pred_sigma
else None
),
loss_type=loss_type,
snr=snr,
return_startx=return_startx,
# rescale_timesteps=rescale_timesteps,
)
\ No newline at end of file
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers import ConfigMixin, SchedulerMixin
from diffusers.configuration_utils import register_to_config
from diffusers.utils import BaseOutput
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class LCMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
denoised: Optional[torch.FloatTensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
Args:
betas (`torch.FloatTensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt ** 2 # Revert sqrt
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
alphas = torch.cat([alphas_bar[:1], alphas])
betas = 1 - alphas
return betas
class LCMScheduler(SchedulerMixin, ConfigMixin):
"""
`LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
non-Markovian guidance.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.0001):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
clip_sample (`bool`, defaults to `True`):
Clip the predicted sample for numerical stability.
clip_sample_range (`float`, defaults to 1.0):
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
set_alpha_to_one (`bool`, defaults to `True`):
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the alpha value at step 0.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
Diffusion.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper).
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
dynamic_thresholding_ratio (`float`, defaults to 0.995):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
"""
# _compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
# Rescale for zero SNR
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
return sample
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
return (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, height, width = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * height * width)
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
sample = sample.reshape(batch_size, channels, height, width)
sample = sample.to(dtype)
return sample
def set_timesteps(self, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
"""
if num_inference_steps > self.config.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.config.num_train_timesteps} timesteps."
)
self.num_inference_steps = num_inference_steps
# LCM Timesteps Setting: # Linear Spacing
c = self.config.num_train_timesteps // lcm_origin_steps
lcm_origin_timesteps = np.asarray(list(range(1, lcm_origin_steps + 1))) * c - 1 # LCM Training Steps Schedule
skipping_step = len(lcm_origin_timesteps) // num_inference_steps
timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] # LCM Inference Steps Schedule
self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
def get_scalings_for_boundary_condition_discrete(self, t):
self.sigma_data = 0.5 # Default: 0.5
# By dividing 0.1: This is almost a delta function at t=0.
c_skip = self.sigma_data ** 2 / ((t / 0.1) ** 2 + self.sigma_data ** 2)
c_out = ((t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data ** 2) ** 0.5)
return c_skip, c_out
def step(
self,
model_output: torch.FloatTensor,
timeindex: int,
timestep: int,
sample: torch.FloatTensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
variance_noise: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
) -> Union[LCMSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
The weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`, defaults to `False`):
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
`use_clipped_model_output` has no effect.
generator (`torch.Generator`, *optional*):
A random number generator.
variance_noise (`torch.FloatTensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`CycleDiffusion`].
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
# 1. get previous step value
prev_timeindex = timeindex + 1
if prev_timeindex < len(self.timesteps):
prev_timestep = self.timesteps[prev_timeindex]
else:
prev_timestep = timestep
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
# 3. Get scalings for boundary conditions
c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
# 4. Different Parameterization:
parameterization = self.config.prediction_type
if parameterization == "epsilon": # noise-prediction
pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
elif parameterization == "sample": # x-prediction
pred_x0 = model_output
elif parameterization == "v_prediction": # v-prediction
pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
# 4. Denoise model output using boundary conditions
denoised = c_out * pred_x0 + c_skip * sample
# 5. Sample z ~ N(0, I), For MultiStep Inference
# Noise is not used for one-step sampling.
if len(self.timesteps) > 1:
noise = torch.randn(model_output.shape).to(model_output.device)
prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
else:
prev_sample = denoised
if not return_dict:
return (prev_sample, denoised)
return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as sample
alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
timesteps = timesteps.to(sample.device)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(sample.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
return sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
def __len__(self):
return self.config.num_train_timesteps
from mmcv import Registry
from diffusion.model.utils import set_grad_checkpoint
MODELS = Registry('models')
def build_model(cfg, use_grad_checkpoint=False, use_fp32_attention=False, gc_step=1, **kwargs):
if isinstance(cfg, str):
cfg = dict(type=cfg)
model = MODELS.build(cfg, default_args=kwargs)
if use_grad_checkpoint:
set_grad_checkpoint(model, use_fp32_attention=use_fp32_attention, gc_step=gc_step)
return model
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
import numpy as np
import torch as th
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = next(
(
obj
for obj in (mean1, logvar1, mean2, logvar2)
if isinstance(obj, th.Tensor)
),
None,
)
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for th.exp().
logvar1, logvar2 = [
x if isinstance(x, th.Tensor) else th.tensor(x, device=tensor.device)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ th.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * th.exp(-logvar2)
)
def approx_standard_normal_cdf(x):
"""
A fast approximation of the cumulative distribution function of the
standard normal.
"""
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
"""
Compute the log-likelihood of a continuous Gaussian distribution.
:param x: the targets
:param means: the Gaussian mean Tensor.
:param log_scales: the Gaussian log stddev Tensor.
:return: a tensor like x of log probabilities (in nats).
"""
centered_x = x - means
inv_stdv = th.exp(-log_scales)
normalized_x = centered_x * inv_stdv
return th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(
normalized_x
)
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
"""
Compute the log-likelihood of a Gaussian distribution discretizing to a
given image.
:param x: the target images. It is assumed that this was uint8 values,
rescaled to the range [-1, 1].
:param means: the Gaussian mean Tensor.
:param log_scales: the Gaussian log stddev Tensor.
:return: a tensor like x of log probabilities (in nats).
"""
assert x.shape == means.shape == log_scales.shape
centered_x = x - means
inv_stdv = th.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = th.where(
x < -0.999,
log_cdf_plus,
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
)
assert log_probs.shape == x.shape
return log_probs
import torch
from tqdm import tqdm
class NoiseScheduleVP:
def __init__(
self,
schedule='discrete',
betas=None,
alphas_cumprod=None,
continuous_beta_0=0.1,
continuous_beta_1=20.,
dtype=torch.float32,
):
"""Create a wrapper class for the forward SDE (VP type).
***
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
***
The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
log_alpha_t = self.marginal_log_mean_coeff(t)
sigma_t = self.marginal_std(t)
lambda_t = self.marginal_lambda(t)
Moreover, as lambda(t) is an invertible function, we also support its inverse function:
t = self.inverse_lambda(lambda_t)
===============================================================
We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
1. For discrete-time DPMs:
For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
t_i = (i + 1) / N
e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
Args:
betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
**Important**: Please pay special attention for the args for `alphas_cumprod`:
The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
alpha_{t_n} = \sqrt{\hat{alpha_n}},
and
log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
2. For continuous-time DPMs:
We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise
schedule are the default settings in Yang Song's ScoreSDE:
Args:
beta_min: A `float` number. The smallest beta for the linear schedule.
beta_max: A `float` number. The largest beta for the linear schedule.
T: A `float` number. The ending time of the forward process.
===============================================================
Args:
schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
'linear' for continuous-time DPMs.
Returns:
A wrapper object of the forward SDE (VP type).
===============================================================
Example:
# For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
>>> ns = NoiseScheduleVP('discrete', betas=betas)
# For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
>>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
# For continuous-time DPMs (VPSDE), linear schedule:
>>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
"""
if schedule not in ['discrete', 'linear']:
raise ValueError(
f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear'"
)
self.schedule = schedule
if schedule == 'discrete':
if betas is not None:
log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
else:
assert alphas_cumprod is not None
log_alphas = 0.5 * torch.log(alphas_cumprod)
self.T = 1.
self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1,)).to(dtype=dtype)
self.total_N = self.log_alpha_array.shape[1]
self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
else:
self.T = 1.
self.total_N = 1000
self.beta_0 = continuous_beta_0
self.beta_1 = continuous_beta_1
def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1):
"""
For some beta schedules such as cosine schedule, the log-SNR has numerical isssues.
We clip the log-SNR near t=T within -5.1 to ensure the stability.
Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE.
"""
log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas))
lambs = log_alphas - log_sigmas
idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda)
if idx > 0:
log_alphas = log_alphas[:-idx]
return log_alphas
def marginal_log_mean_coeff(self, t):
"""
Compute log(alpha_t) of a given continuous-time label t in [0, T].
"""
if self.schedule == 'discrete':
return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
self.log_alpha_array.to(t.device)).reshape((-1))
elif self.schedule == 'linear':
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
def marginal_alpha(self, t):
"""
Compute alpha_t of a given continuous-time label t in [0, T].
"""
return torch.exp(self.marginal_log_mean_coeff(t))
def marginal_std(self, t):
"""
Compute sigma_t of a given continuous-time label t in [0, T].
"""
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
def marginal_lambda(self, t):
"""
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
"""
log_mean_coeff = self.marginal_log_mean_coeff(t)
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
return log_mean_coeff - log_std
def inverse_lambda(self, lamb):
"""
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
"""
if self.schedule == 'linear':
tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
Delta = self.beta_0 ** 2 + tmp
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
elif self.schedule == 'discrete':
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
torch.flip(self.t_array.to(lamb.device), [1]))
return t.reshape((-1,))
def model_wrapper(
model,
noise_schedule,
model_type="noise",
model_kwargs={},
guidance_type="uncond",
condition=None,
unconditional_condition=None,
guidance_scale=1.,
classifier_fn=None,
classifier_kwargs={},
):
"""Create a wrapper function for the noise prediction model.
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
We support four types of the diffusion model by setting `model_type`:
1. "noise": noise prediction model. (Trained by predicting noise).
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
3. "v": velocity prediction model. (Trained by predicting the velocity).
The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
arXiv preprint arXiv:2202.00512 (2022).
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
arXiv preprint arXiv:2210.02303 (2022).
4. "score": marginal score function. (Trained by denoising score matching).
Note that the score function and the noise prediction model follows a simple relationship:
```
noise(x_t, t) = -sigma_t * score(x_t, t)
```
We support three types of guided sampling by DPMs by setting `guidance_type`:
1. "uncond": unconditional sampling by DPMs.
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
The input `classifier_fn` has the following format:
``
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
``
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
The input `model` has the following format:
``
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
``
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
arXiv preprint arXiv:2207.12598 (2022).
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
or continuous-time labels (i.e. epsilon to T).
We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
``
def model_fn(x, t_continuous) -> noise:
t_input = get_model_input_time(t_continuous)
return noise_pred(model, x, t_input, **model_kwargs)
``
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
===============================================================
Args:
model: A diffusion model with the corresponding format described above.
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
model_type: A `str`. The parameterization type of the diffusion model.
"noise" or "x_start" or "v" or "score".
model_kwargs: A `dict`. A dict for the other inputs of the model function.
guidance_type: A `str`. The type of the guidance for sampling.
"uncond" or "classifier" or "classifier-free".
condition: A pytorch tensor. The condition for the guided sampling.
Only used for "classifier" or "classifier-free" guidance type.
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
Only used for "classifier-free" guidance type.
guidance_scale: A `float`. The scale for the guided sampling.
classifier_fn: A classifier function. Only used for the classifier guidance.
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
Returns:
A noise prediction model that accepts the noised data and the continuous time as the inputs.
"""
def get_model_input_time(t_continuous):
"""
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
For continuous-time DPMs, we just use `t_continuous`.
"""
if noise_schedule.schedule == 'discrete':
return (t_continuous - 1. / noise_schedule.total_N) * 1000.
else:
return t_continuous
def noise_pred_fn(x, t_continuous, cond=None):
t_input = get_model_input_time(t_continuous)
if cond is None:
output = model(x, t_input, **model_kwargs)
else:
output = model(x, t_input, cond, **model_kwargs)
if model_type == "noise":
return output
elif model_type == "x_start":
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim())
elif model_type == "v":
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x
elif model_type == "score":
sigma_t = noise_schedule.marginal_std(t_continuous)
return -expand_dims(sigma_t, x.dim()) * output
def cond_grad_fn(x, t_input):
"""
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
"""
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
return torch.autograd.grad(log_prob.sum(), x_in)[0]
def model_fn(x, t_continuous):
"""
The noise predicition model function that is used for DPM-Solver.
"""
if guidance_type == "uncond":
return noise_pred_fn(x, t_continuous)
elif guidance_type == "classifier":
assert classifier_fn is not None
t_input = get_model_input_time(t_continuous)
cond_grad = cond_grad_fn(x, t_input)
sigma_t = noise_schedule.marginal_std(t_continuous)
noise = noise_pred_fn(x, t_continuous)
return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad
elif guidance_type == "classifier-free":
if guidance_scale == 1. or unconditional_condition is None:
return noise_pred_fn(x, t_continuous, cond=condition)
x_in = torch.cat([x] * 2)
t_in = torch.cat([t_continuous] * 2)
c_in = torch.cat([unconditional_condition, condition])
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
return noise_uncond + guidance_scale * (noise - noise_uncond)
assert model_type in ["noise", "x_start", "v", "score"]
assert guidance_type in ["uncond", "classifier", "classifier-free"]
return model_fn
class DPM_Solver:
def __init__(
self,
model_fn,
noise_schedule,
algorithm_type="dpmsolver++",
correcting_x0_fn=None,
correcting_xt_fn=None,
thresholding_max_val=1.,
dynamic_thresholding_ratio=0.995,
):
"""Construct a DPM-Solver.
We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).
We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
DPMs (such as stable-diffusion).
To support advanced algorithms in image-to-image applications, we also support corrector functions for
both x0 and xt.
Args:
model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
``
def model_fn(x, t_continuous):
return noise
``
The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
correcting_x0_fn: A `str` or a function with the following format:
```
def correcting_x0_fn(x0, t):
x0_new = ...
return x0_new
```
This function is to correct the outputs of the data prediction model at each sampling step. e.g.,
```
x0_pred = data_pred_model(xt, t)
if correcting_x0_fn is not None:
x0_pred = correcting_x0_fn(x0_pred, t)
xt_1 = update(x0_pred, xt, t)
```
If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
correcting_xt_fn: A function with the following format:
```
def correcting_xt_fn(xt, t, step):
x_new = ...
return x_new
```
This function is to correct the intermediate samples xt at each sampling step. e.g.,
```
xt = ...
xt = correcting_xt_fn(xt, t, step)
```
thresholding_max_val: A `float`. The max value for thresholding.
Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
[1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
"""
self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
self.noise_schedule = noise_schedule
assert algorithm_type in ["dpmsolver", "dpmsolver++"]
self.algorithm_type = algorithm_type
if correcting_x0_fn == "dynamic_thresholding":
self.correcting_x0_fn = self.dynamic_thresholding_fn
else:
self.correcting_x0_fn = correcting_x0_fn
self.correcting_xt_fn = correcting_xt_fn
self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
self.thresholding_max_val = thresholding_max_val
def dynamic_thresholding_fn(self, x0, t):
"""
The dynamic thresholding method.
"""
dims = x0.dim()
p = self.dynamic_thresholding_ratio
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s
return x0
def noise_prediction_fn(self, x, t):
"""
Return the noise prediction model.
"""
return self.model(x, t)
def data_prediction_fn(self, x, t):
"""
Return the data prediction model (with corrector).
"""
noise = self.noise_prediction_fn(x, t)
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
x0 = (x - sigma_t * noise) / alpha_t
if self.correcting_x0_fn is not None:
x0 = self.correcting_x0_fn(x0, t)
return x0
def model_fn(self, x, t):
"""
Convert the model to the noise prediction model or the data prediction model.
"""
if self.algorithm_type == "dpmsolver++":
return self.data_prediction_fn(x, t)
else:
return self.noise_prediction_fn(x, t)
def get_time_steps(self, skip_type, t_T, t_0, N, device):
"""Compute the intermediate time steps for sampling.
Args:
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
- 'logSNR': uniform logSNR for the time steps.
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
N: A `int`. The total number of the spacing of the time steps.
device: A torch device.
Returns:
A pytorch tensor of the time steps, with the shape (N + 1,).
"""
if skip_type == 'logSNR':
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
return self.noise_schedule.inverse_lambda(logSNR_steps)
elif skip_type == 'time_uniform':
return torch.linspace(t_T, t_0, N + 1).to(device)
elif skip_type == 'time_quadratic':
t_order = 2
return (
torch.linspace(
t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1
)
.pow(t_order)
.to(device)
)
else:
raise ValueError(
f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'"
)
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
"""
Get the order of each step for sampling by the singlestep DPM-Solver.
We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
- If order == 1:
We take `steps` of DPM-Solver-1 (i.e. DDIM).
- If order == 2:
- Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
- If steps % 2 == 0, we use K steps of DPM-Solver-2.
- If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
- If order == 3:
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
- If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
- If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
- If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
============================================
Args:
order: A `int`. The max order for the solver (2 or 3).
steps: A `int`. The total number of function evaluations (NFE).
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
- 'logSNR': uniform logSNR for the time steps.
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
device: A torch device.
Returns:
orders: A list of the solver order of each step.
"""
if order == 3:
K = steps // 3 + 1
if steps % 3 == 0:
orders = [3, ] * (K - 2) + [2, 1]
elif steps % 3 == 1:
orders = [3, ] * (K - 1) + [1]
else:
orders = [3, ] * (K - 1) + [2]
elif order == 2:
if steps % 2 == 0:
K = steps // 2
orders = [2, ] * K
else:
K = steps // 2 + 1
orders = [2, ] * (K - 1) + [1]
elif order == 1:
K = 1
orders = [1, ] * steps
else:
raise ValueError("'order' must be '1' or '2' or '3'.")
if skip_type == 'logSNR':
# To reproduce the results in DPM-Solver paper
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
else:
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
torch.cumsum(torch.tensor([0, ] + orders), 0).to(device)]
return timesteps_outer, orders
def denoise_to_zero_fn(self, x, s):
"""
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
"""
return self.data_prediction_fn(x, s)
def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
"""
DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (1,).
t: A pytorch tensor. The ending time, with the shape (1,).
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s`.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
ns = self.noise_schedule
dims = x.dim()
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
if self.algorithm_type == "dpmsolver++":
phi_1 = torch.expm1(-h)
if model_s is None:
model_s = self.model_fn(x, s)
x_t = (
sigma_t / sigma_s * x
- alpha_t * phi_1 * model_s
)
else:
phi_1 = torch.expm1(h)
if model_s is None:
model_s = self.model_fn(x, s)
x_t = (
torch.exp(log_alpha_t - log_alpha_s) * x
- (sigma_t * phi_1) * model_s
)
return (x_t, {'model_s': model_s}) if return_intermediate else x_t
def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
solver_type='dpmsolver'):
"""
Singlestep solver DPM-Solver-2 from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (1,).
t: A pytorch tensor. The ending time, with the shape (1,).
r1: A `float`. The hyperparameter of the second-order solver.
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ['dpmsolver', 'taylor']:
raise ValueError(
f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}"
)
if r1 is None:
r1 = 0.5
ns = self.noise_schedule
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
lambda_s1 = lambda_s + r1 * h
s1 = ns.inverse_lambda(lambda_s1)
log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
s1), ns.marginal_log_mean_coeff(t)
sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
if self.algorithm_type == "dpmsolver++":
phi_11 = torch.expm1(-r1 * h)
phi_1 = torch.expm1(-h)
if model_s is None:
model_s = self.model_fn(x, s)
x_s1 = (
(sigma_s1 / sigma_s) * x
- (alpha_s1 * phi_11) * model_s
)
model_s1 = self.model_fn(x_s1, s1)
if solver_type == 'dpmsolver':
x_t = (
(sigma_t / sigma_s) * x
- (alpha_t * phi_1) * model_s
- (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
)
elif solver_type == 'taylor':
x_t = (
(sigma_t / sigma_s) * x
- (alpha_t * phi_1) * model_s
+ (1. / r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s)
)
else:
phi_11 = torch.expm1(r1 * h)
phi_1 = torch.expm1(h)
if model_s is None:
model_s = self.model_fn(x, s)
x_s1 = (
torch.exp(log_alpha_s1 - log_alpha_s) * x
- (sigma_s1 * phi_11) * model_s
)
model_s1 = self.model_fn(x_s1, s1)
if solver_type == 'dpmsolver':
x_t = (
torch.exp(log_alpha_t - log_alpha_s) * x
- (sigma_t * phi_1) * model_s
- (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
)
elif solver_type == 'taylor':
x_t = (
torch.exp(log_alpha_t - log_alpha_s) * x
- (sigma_t * phi_1) * model_s
- (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s)
)
if return_intermediate:
return x_t, {'model_s': model_s, 'model_s1': model_s1}
else:
return x_t
def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
return_intermediate=False, solver_type='dpmsolver'):
"""
Singlestep solver DPM-Solver-3 from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (1,).
t: A pytorch tensor. The ending time, with the shape (1,).
r1: A `float`. The hyperparameter of the third-order solver.
r2: A `float`. The hyperparameter of the third-order solver.
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ['dpmsolver', 'taylor']:
raise ValueError(
f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}"
)
if r1 is None:
r1 = 1. / 3.
if r2 is None:
r2 = 2. / 3.
ns = self.noise_schedule
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
lambda_s1 = lambda_s + r1 * h
lambda_s2 = lambda_s + r2 * h
s1 = ns.inverse_lambda(lambda_s1)
s2 = ns.inverse_lambda(lambda_s2)
log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
s2), ns.marginal_std(t)
alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
if self.algorithm_type == "dpmsolver++":
phi_11 = torch.expm1(-r1 * h)
phi_12 = torch.expm1(-r2 * h)
phi_1 = torch.expm1(-h)
phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
phi_2 = phi_1 / h + 1.
phi_3 = phi_2 / h - 0.5
if model_s is None:
model_s = self.model_fn(x, s)
if model_s1 is None:
x_s1 = (
(sigma_s1 / sigma_s) * x
- (alpha_s1 * phi_11) * model_s
)
model_s1 = self.model_fn(x_s1, s1)
x_s2 = (
(sigma_s2 / sigma_s) * x
- (alpha_s2 * phi_12) * model_s
+ r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
)
model_s2 = self.model_fn(x_s2, s2)
if solver_type == 'dpmsolver':
x_t = (
(sigma_t / sigma_s) * x
- (alpha_t * phi_1) * model_s
+ (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
)
elif solver_type == 'taylor':
D1_0 = (1. / r1) * (model_s1 - model_s)
D1_1 = (1. / r2) * (model_s2 - model_s)
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
x_t = (
(sigma_t / sigma_s) * x
- (alpha_t * phi_1) * model_s
+ (alpha_t * phi_2) * D1
- (alpha_t * phi_3) * D2
)
else:
phi_11 = torch.expm1(r1 * h)
phi_12 = torch.expm1(r2 * h)
phi_1 = torch.expm1(h)
phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
phi_2 = phi_1 / h - 1.
phi_3 = phi_2 / h - 0.5
if model_s is None:
model_s = self.model_fn(x, s)
if model_s1 is None:
x_s1 = (
(torch.exp(log_alpha_s1 - log_alpha_s)) * x
- (sigma_s1 * phi_11) * model_s
)
model_s1 = self.model_fn(x_s1, s1)
x_s2 = (
(torch.exp(log_alpha_s2 - log_alpha_s)) * x
- (sigma_s2 * phi_12) * model_s
- r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
)
model_s2 = self.model_fn(x_s2, s2)
if solver_type == 'dpmsolver':
x_t = (
(torch.exp(log_alpha_t - log_alpha_s)) * x
- (sigma_t * phi_1) * model_s
- (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s)
)
elif solver_type == 'taylor':
D1_0 = (1. / r1) * (model_s1 - model_s)
D1_1 = (1. / r2) * (model_s2 - model_s)
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
x_t = (
(torch.exp(log_alpha_t - log_alpha_s)) * x
- (sigma_t * phi_1) * model_s
- (sigma_t * phi_2) * D1
- (sigma_t * phi_3) * D2
)
if return_intermediate:
return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
else:
return x_t
def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
"""
Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
model_prev_list: A list of pytorch tensor. The previous computed model values.
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
t: A pytorch tensor. The ending time, with the shape (1,).
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ['dpmsolver', 'taylor']:
raise ValueError(
f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}"
)
ns = self.noise_schedule
model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
t_prev_0), ns.marginal_lambda(t)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0 = h_0 / h
D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
if self.algorithm_type == "dpmsolver++":
phi_1 = torch.expm1(-h)
if solver_type == 'dpmsolver':
x_t = (
(sigma_t / sigma_prev_0) * x
- (alpha_t * phi_1) * model_prev_0
- 0.5 * (alpha_t * phi_1) * D1_0
)
elif solver_type == 'taylor':
x_t = (
(sigma_t / sigma_prev_0) * x
- (alpha_t * phi_1) * model_prev_0
+ (alpha_t * (phi_1 / h + 1.)) * D1_0
)
else:
phi_1 = torch.expm1(h)
if solver_type == 'dpmsolver':
x_t = (
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
- (sigma_t * phi_1) * model_prev_0
- 0.5 * (sigma_t * phi_1) * D1_0
)
elif solver_type == 'taylor':
x_t = (
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
- (sigma_t * phi_1) * model_prev_0
- (sigma_t * (phi_1 / h - 1.)) * D1_0
)
return x_t
def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'):
"""
Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
model_prev_list: A list of pytorch tensor. The previous computed model values.
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
t: A pytorch tensor. The ending time, with the shape (1,).
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
ns = self.noise_schedule
model_prev_2, model_prev_1, model_prev_0 = model_prev_list
t_prev_2, t_prev_1, t_prev_0 = t_prev_list
lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
h_1 = lambda_prev_1 - lambda_prev_2
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0, r1 = h_0 / h, h_1 / h
D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
D1_1 = (1. / r1) * (model_prev_1 - model_prev_2)
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
D2 = (1. / (r0 + r1)) * (D1_0 - D1_1)
if self.algorithm_type == "dpmsolver++":
phi_1 = torch.expm1(-h)
phi_2 = phi_1 / h + 1.
phi_3 = phi_2 / h - 0.5
return (
(sigma_t / sigma_prev_0) * x
- (alpha_t * phi_1) * model_prev_0
+ (alpha_t * phi_2) * D1
- (alpha_t * phi_3) * D2
)
else:
phi_1 = torch.expm1(h)
phi_2 = phi_1 / h - 1.
phi_3 = phi_2 / h - 0.5
return (
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
- (sigma_t * phi_1) * model_prev_0
- (sigma_t * phi_2) * D1
- (sigma_t * phi_3) * D2
)
def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None,
r2=None):
"""
Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (1,).
t: A pytorch tensor. The ending time, with the shape (1,).
order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
r1: A `float`. The hyperparameter of the second-order or third-order solver.
r2: A `float`. The hyperparameter of the third-order solver.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if order == 1:
return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
elif order == 2:
return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
solver_type=solver_type, r1=r1)
elif order == 3:
return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
solver_type=solver_type, r1=r1, r2=r2)
else:
raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'):
"""
Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
model_prev_list: A list of pytorch tensor. The previous computed model values.
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
t: A pytorch tensor. The ending time, with the shape (1,).
order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if order == 1:
return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
elif order == 2:
return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
elif order == 3:
return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
else:
raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
solver_type='dpmsolver'):
"""
The adaptive step size solver based on singlestep DPM-Solver.
Args:
x: A pytorch tensor. The initial value at time `t_T`.
order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
h_init: A `float`. The initial step size (for logSNR).
atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
current time and `t_0` is less than `t_err`. The default setting is 1e-5.
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_0: A pytorch tensor. The approximated solution at time `t_0`.
[1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
"""
ns = self.noise_schedule
s = t_T * torch.ones((1,)).to(x)
lambda_s = ns.marginal_lambda(s)
lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
h = h_init * torch.ones_like(s).to(x)
x_prev = x
nfe = 0
if order == 2:
r1 = 0.5
lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
solver_type=solver_type,
**kwargs)
elif order == 3:
r1, r2 = 1. / 3., 2. / 3.
lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
return_intermediate=True,
solver_type=solver_type)
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
solver_type=solver_type,
**kwargs)
else:
raise ValueError(
f"For adaptive step size solver, order must be 2 or 3, got {order}"
)
while torch.abs((s - t_0)).mean() > t_err:
t = ns.inverse_lambda(lambda_s + h)
x_lower, lower_noise_kwargs = lower_update(x, s, t)
x_higher = higher_update(x, s, t, **lower_noise_kwargs)
delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
E = norm_fn((x_higher - x_lower) / delta).max()
if torch.all(E <= 1.):
x = x_higher
s = t
x_prev = x_lower
lambda_s = ns.marginal_lambda(s)
h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
nfe += order
print('adaptive solver nfe', nfe)
return x
def add_noise(self, x, t, noise=None):
"""
Compute the noised input xt = alpha_t * x + sigma_t * noise.
Args:
x: A `torch.Tensor` with shape `(batch_size, *shape)`.
t: A `torch.Tensor` with shape `(t_size,)`.
Returns:
xt with shape `(t_size, batch_size, *shape)`.
"""
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
if noise is None:
noise = torch.randn((t.shape[0], *x.shape), device=x.device)
x = x.reshape((-1, *x.shape))
xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise
return xt.squeeze(0) if t.shape[0] == 1 else xt
def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
atol=0.0078, rtol=0.05, return_intermediate=False,
):
"""
Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver.
For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.
"""
t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start
t_T = self.noise_schedule.T if t_end is None else t_end
assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
return self.sample(x, steps=steps, t_start=t_0, t_end=t_T, order=order, skip_type=skip_type,
method=method, lower_order_final=lower_order_final, denoise_to_zero=denoise_to_zero,
solver_type=solver_type,
atol=atol, rtol=rtol, return_intermediate=return_intermediate)
def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
atol=0.0078, rtol=0.05, return_intermediate=False,
):
"""
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
=====================================================
We support the following algorithms for both noise prediction model and data prediction model:
- 'singlestep':
Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
The total number of function evaluations (NFE) == `steps`.
Given a fixed NFE == `steps`, the sampling procedure is:
- If `order` == 1:
- Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
- If `order` == 2:
- Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
- If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
- If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
- If `order` == 3:
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
- If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
- If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
- If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
- 'multistep':
Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
We initialize the first `order` values by lower order multistep solvers.
Given a fixed NFE == `steps`, the sampling procedure is:
Denote K = steps.
- If `order` == 1:
- We use K steps of DPM-Solver-1 (i.e. DDIM).
- If `order` == 2:
- We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
- If `order` == 3:
- We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
- 'singlestep_fixed':
Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
- 'adaptive':
Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
(NFE) and the sample quality.
- If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
- If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
=====================================================
Some advices for choosing the algorithm:
- For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
e.g., DPM-Solver:
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
skip_type='time_uniform', method='singlestep')
e.g., DPM-Solver++:
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
skip_type='time_uniform', method='singlestep')
- For **guided sampling with large guidance scale** by DPMs:
Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
e.g.
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
skip_type='time_uniform', method='multistep')
We support three types of `skip_type`:
- 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
- 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
- 'time_quadratic': quadratic time for the time steps.
=====================================================
Args:
x: A pytorch tensor. The initial value at time `t_start`
e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
steps: A `int`. The total number of function evaluations (NFE).
t_start: A `float`. The starting time of the sampling.
If `T` is None, we use self.noise_schedule.T (default is 1.0).
t_end: A `float`. The ending time of the sampling.
If `t_end` is None, we use 1. / self.noise_schedule.total_N.
e.g. if total_N == 1000, we have `t_end` == 1e-3.
For discrete-time DPMs:
- We recommend `t_end` == 1. / self.noise_schedule.total_N.
For continuous-time DPMs:
- We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
order: A `int`. The order of DPM-Solver.
skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
for diffusion models sampling by diffusion SDEs for low-resolutional images
(such as CIFAR-10). However, we observed that such trick does not matter for
high-resolutional images. As it needs an additional NFE, we do not recommend
it for high-resolutional images.
lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
Only valid for `method=multistep` and `steps < 15`. We empirically find that
this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
(especially for steps <= 10). So we recommend to set it to be `True`.
solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`.
atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
return_intermediate: A `bool`. Whether to save the xt at each step.
When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0.
Returns:
x_end: A pytorch tensor. The approximated solution at time `t_end`.
"""
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
t_T = self.noise_schedule.T if t_start is None else t_start
assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
if return_intermediate:
assert method in ['multistep', 'singlestep',
'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values"
if self.correcting_xt_fn is not None:
assert method in ['multistep', 'singlestep',
'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None"
device = x.device
intermediates = []
with torch.no_grad():
if method == 'adaptive':
x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
solver_type=solver_type)
elif method == 'multistep':
assert steps >= order
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
assert timesteps.shape[0] - 1 == steps
# Init the initial values.
step = 0
t = timesteps[step]
t_prev_list = [t]
model_prev_list = [self.model_fn(x, t)]
if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step)
if return_intermediate:
intermediates.append(x)
# Init the first `order` values by lower order multistep DPM-Solver.
for step in range(1, order):
t = timesteps[step]
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step,
solver_type=solver_type)
if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step)
if return_intermediate:
intermediates.append(x)
t_prev_list.append(t)
model_prev_list.append(self.model_fn(x, t))
# Compute the remaining values by `order`-th order multistep DPM-Solver.
for step in tqdm(range(order, steps + 1)):
t = timesteps[step]
# We only use lower order for steps < 10
if lower_order_final and steps < 10:
step_order = min(order, steps + 1 - step)
else:
step_order = order
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order,
solver_type=solver_type)
if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step)
if return_intermediate:
intermediates.append(x)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = t
# We do not need to evaluate the final model value.
if step < steps:
model_prev_list[-1] = self.model_fn(x, t)
elif method in ['singlestep', 'singlestep_fixed']:
if method == 'singlestep':
timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps,
order=order,
skip_type=skip_type,
t_T=t_T, t_0=t_0,
device=device)
elif method == 'singlestep_fixed':
K = steps // order
orders = [order, ] * K
timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
for step, order in enumerate(orders):
s, t = timesteps_outer[step], timesteps_outer[step + 1]
timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order,
device=device)
lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
h = lambda_inner[-1] - lambda_inner[0]
r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step)
if return_intermediate:
intermediates.append(x)
else:
raise ValueError(f"Got wrong method {method}")
if denoise_to_zero:
t = torch.ones((1,)).to(device) * t_0
x = self.denoise_to_zero_fn(x, t)
if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step + 1)
if return_intermediate:
intermediates.append(x)
return (x, intermediates) if return_intermediate else x
#############################################################
# other utility functions
#############################################################
def interpolate_fn(x, xp, yp):
"""
A piecewise linear function y = f(x), using xp and yp as keypoints.
We implement f(x) in a differentiable way (i.e. applicable for autograd).
The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
Args:
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
yp: PyTorch tensor with shape [C, K].
Returns:
The function values f(x), with shape [N, C].
"""
N, K = x.shape[0], xp.shape[1]
all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
x_idx = torch.argmin(x_indices, dim=2)
cand_start_idx = x_idx - 1
start_idx = torch.where(
torch.eq(x_idx, 0),
torch.tensor(1, device=x.device),
torch.where(
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
),
)
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
start_idx2 = torch.where(
torch.eq(x_idx, 0),
torch.tensor(0, device=x.device),
torch.where(
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
),
)
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
return start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
def expand_dims(v, dims):
"""
Expand the tensor `v` to the dim `dims`.
Args:
`v`: a PyTorch tensor with shape [N].
`dim`: a `int`.
Returns:
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
"""
return v[(...,) + (None,) * (dims - 1)]
\ No newline at end of file
import random
import numpy as np
from tqdm import tqdm
from diffusion.model.utils import *
# ----------------------------------------------------------------------------
# Proposed EDM sampler (Algorithm 2).
def edm_sampler(
net, latents, class_labels=None, cfg_scale=None, randn_like=torch.randn_like,
num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, **kwargs
):
# Adjust noise levels based on what's supported by the network.
sigma_min = max(sigma_min, net.sigma_min)
sigma_max = min(sigma_max, net.sigma_max)
# Time step discretization.
step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
# Main sampling loop.
x_next = latents.to(torch.float64) * t_steps[0]
for i, (t_cur, t_next) in tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:])))): # 0, ..., N-1
x_cur = x_next
# Increase noise temporarily.
gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
t_hat = net.round_sigma(t_cur + gamma * t_cur)
x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
# Euler step.
denoised = net(x_hat.float(), t_hat, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64)
d_cur = (x_hat - denoised) / t_hat
x_next = x_hat + (t_next - t_hat) * d_cur
# Apply 2nd order correction.
if i < num_steps - 1:
denoised = net(x_next.float(), t_next, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64)
d_prime = (x_next - denoised) / t_next
x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
return x_next
# ----------------------------------------------------------------------------
# Generalized ablation sampler, representing the superset of all sampling
# methods discussed in the paper.
def ablation_sampler(
net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like,
num_steps=18, sigma_min=None, sigma_max=None, rho=7,
solver='heun', discretization='edm', schedule='linear', scaling='none',
epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
assert solver in ['euler', 'heun']
assert discretization in ['vp', 've', 'iddpm', 'edm']
assert schedule in ['vp', 've', 'linear']
assert scaling in ['vp', 'none']
# Helper functions for VP & VE noise level schedules.
vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (
sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
ve_sigma = lambda t: t.sqrt()
ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
ve_sigma_inv = lambda sigma: sigma ** 2
# Select default noise level range based on the specified time step discretization.
if sigma_min is None:
vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s)
sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
if sigma_max is None:
vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1)
sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]
# Adjust noise levels based on what's supported by the network.
sigma_min = max(sigma_min, net.sigma_min)
sigma_max = min(sigma_max, net.sigma_max)
# Compute corresponding betas for VP.
vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d
# Define time steps in terms of noise level.
step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
if discretization == 'vp':
orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
elif discretization == 've':
orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
sigma_steps = ve_sigma(orig_t_steps)
elif discretization == 'iddpm':
u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
else:
assert discretization == 'edm'
sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
# Define noise level schedule.
if schedule == 'vp':
sigma = vp_sigma(vp_beta_d, vp_beta_min)
sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
elif schedule == 've':
sigma = ve_sigma
sigma_deriv = ve_sigma_deriv
sigma_inv = ve_sigma_inv
else:
assert schedule == 'linear'
sigma = lambda t: t
sigma_deriv = lambda t: 1
sigma_inv = lambda sigma: sigma
# Define scaling schedule.
if scaling == 'vp':
s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
else:
assert scaling == 'none'
s = lambda t: 1
s_deriv = lambda t: 0
# Compute final time steps based on the corresponding noise levels.
t_steps = sigma_inv(net.round_sigma(sigma_steps))
t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
# Main sampling loop.
t_next = t_steps[0]
x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
x_cur = x_next
# Increase noise temporarily.
gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(
t_hat) * S_noise * randn_like(x_cur)
# Euler step.
h = t_next - t_hat
denoised = net(x_hat.float() / s(t_hat), sigma(t_hat), class_labels, cfg_scale, feat=feat)['x'].to(
torch.float64)
d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(
t_hat) / sigma(t_hat) * denoised
x_prime = x_hat + alpha * h * d_cur
t_prime = t_hat + alpha * h
# Apply 2nd order correction.
if solver == 'euler' or i == num_steps - 1:
x_next = x_hat + h * d_cur
else:
assert solver == 'heun'
denoised = net(x_prime.float() / s(t_prime), sigma(t_prime), class_labels, cfg_scale, feat=feat)['x'].to(
torch.float64)
d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(
t_prime) * s(t_prime) / sigma(t_prime) * denoised
x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)
return x_next
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
import enum
import math
import numpy as np
import torch as th
import torch.nn.functional as F
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
class ModelMeanType(enum.Enum):
"""
Which type of output the model predicts.
"""
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
START_X = enum.auto() # the model predicts x_0
EPSILON = enum.auto() # the model predicts epsilon
class ModelVarType(enum.Enum):
"""
What is used as the model's output variance.
The LEARNED_RANGE option has been added to allow the model to predict
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
"""
LEARNED = enum.auto()
FIXED_SMALL = enum.auto()
FIXED_LARGE = enum.auto()
LEARNED_RANGE = enum.auto()
class LossType(enum.Enum):
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
RESCALED_MSE = (
enum.auto()
) # use raw MSE loss (with RESCALED_KL when learning variances)
KL = enum.auto() # use the variational lower-bound
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
def is_vb(self):
return self in [LossType.KL, LossType.RESCALED_KL]
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
warmup_time = int(num_diffusion_timesteps * warmup_frac)
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
return betas
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
"""
This is the deprecated API for creating beta schedules.
See get_named_beta_schedule() for the new library of schedules.
"""
if beta_schedule == "quad":
betas = (
np.linspace(
beta_start ** 0.5,
beta_end ** 0.5,
num_diffusion_timesteps,
dtype=np.float64,
)
** 2
)
elif beta_schedule == "linear":
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
elif beta_schedule == "warmup10":
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
elif beta_schedule == "warmup50":
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
elif beta_schedule == "const":
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
betas = 1.0 / np.linspace(
num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
)
else:
raise NotImplementedError(beta_schedule)
assert betas.shape == (num_diffusion_timesteps,)
return betas
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
"""
Get a pre-defined beta schedule for the given name.
The beta schedule library consists of beta schedules which remain similar
in the limit of num_diffusion_timesteps.
Beta schedules may be added, but should not be removed or changed once
they are committed to maintain backwards compatibility.
"""
if schedule_name == "linear":
# Linear schedule from Ho et al, extended to work for any number of
# diffusion steps.
scale = 1000 / num_diffusion_timesteps
return get_beta_schedule(
"linear",
beta_start=scale * 0.0001,
beta_end=scale * 0.02,
num_diffusion_timesteps=num_diffusion_timesteps,
)
elif schedule_name == "squaredcos_cap_v2":
return betas_for_alpha_bar(
num_diffusion_timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
class GaussianDiffusion:
"""
Utilities for training and sampling diffusion models.
Original ported from this codebase:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
:param betas: a 1-D numpy array of betas for each diffusion timestep,
starting at T and going to 1.
"""
def __init__(
self,
*,
betas,
model_mean_type,
model_var_type,
loss_type,
snr=False,
return_startx=False,
):
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
self.loss_type = loss_type
self.snr = snr
self.return_startx = return_startx
# Use float64 for accuracy.
betas = np.array(betas, dtype=np.float64)
self.betas = betas
assert len(betas.shape) == 1, "betas must be 1-D"
assert (betas > 0).all() and (betas <= 1).all()
self.num_timesteps = int(betas.shape[0])
alphas = 1.0 - betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self.posterior_variance = (
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = np.log(
np.append(self.posterior_variance[1], self.posterior_variance[1:])
) if len(self.posterior_variance) > 1 else np.array([])
self.posterior_mean_coef1 = (
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
self.posterior_mean_coef2 = (
(1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
)
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data for a given number of diffusion steps.
In other words, sample from q(x_t | x_0).
:param x_start: the initial data batch.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:param noise: if specified, the split-out normal noise.
:return: A noisy version of x_start.
"""
if noise is None:
noise = th.randn_like(x_start)
assert noise.shape == x_start.shape
return (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior:
q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
self.posterior_log_variance_clipped, t, x_t.shape
)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
== posterior_log_variance_clipped.shape[0]
== x_start.shape[0]
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
"""
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0.
:param model: the model, which takes a signal and a batch of timesteps
as input.
:param x: the [N x C x ...] tensor at time t.
:param t: a 1-D Tensor of timesteps.
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample. Applies before
clip_denoised.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict with the following keys:
- 'mean': the model mean output.
- 'variance': the model variance output.
- 'log_variance': the log of 'variance'.
- 'pred_xstart': the prediction for x_0.
"""
if model_kwargs is None:
model_kwargs = {}
B, C = x.shape[:2]
assert t.shape == (B,)
model_output = model(x, t, **model_kwargs)
if isinstance(model_output, tuple):
model_output, extra = model_output
else:
extra = None
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
assert model_output.shape == (B, C * 2, *x.shape[2:])
model_output, model_var_values = th.split(model_output, C, dim=1)
min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
# The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2
model_log_variance = frac * max_log + (1 - frac) * min_log
model_variance = th.exp(model_log_variance)
elif self.model_var_type in [ModelVarType.FIXED_LARGE, ModelVarType.FIXED_SMALL]:
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so
# to get a better decoder log likelihood.
ModelVarType.FIXED_LARGE: (
np.append(self.posterior_variance[1], self.betas[1:]),
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
),
ModelVarType.FIXED_SMALL: (
self.posterior_variance,
self.posterior_log_variance_clipped,
),
}[self.model_var_type]
model_variance = _extract_into_tensor(model_variance, t, x.shape)
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
else:
model_variance = th.zeros_like(model_output)
model_log_variance = th.zeros_like(model_output)
def process_xstart(x):
if denoised_fn is not None:
x = denoised_fn(x)
return x.clamp(-1, 1) if clip_denoised else x
if self.model_mean_type == ModelMeanType.START_X:
pred_xstart = process_xstart(model_output)
else:
pred_xstart = process_xstart(
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
return {
"mean": model_mean,
"variance": model_variance,
"log_variance": model_log_variance,
"pred_xstart": pred_xstart,
"extra": extra,
}
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
)
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
"""
Compute the mean for the previous step, given a function cond_fn that
computes the gradient of a conditional log probability with respect to
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
condition on y.
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
"""
gradient = cond_fn(x, t, **model_kwargs)
return p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
"""
Compute what the p_mean_variance output would have been, should the
model's score function be conditioned by cond_fn.
See condition_mean() for details on cond_fn.
Unlike condition_mean(), this instead uses the conditioning strategy
from Song et al (2020).
"""
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
out = p_mean_var.copy()
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
return out
def p_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
):
"""
Sample x_{t-1} from the model at the given timestep.
:param model: the model to sample from.
:param x: the current tensor at x_{t-1}.
:param t: the value of t, starting at 0 for the first diffusion step.
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param cond_fn: if not None, this is a gradient function that acts
similarly to the model.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
- 'sample': a random sample from the model.
- 'pred_xstart': a prediction of x_0.
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
noise = th.randn_like(x)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
if cond_fn is not None:
out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def p_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
):
"""
Generate samples from the model.
:param model: the model module.
:param shape: the shape of the samples, (N, C, H, W).
:param noise: if specified, the noise from the encoder to sample.
Should be of the same shape as `shape`.
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param cond_fn: if not None, this is a gradient function that acts
similarly to the model.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param device: if specified, the device to create the samples on.
If not specified, use a model parameter's device.
:param progress: if True, show a tqdm progress bar.
:return: a non-differentiable batch of samples.
"""
final = None
for sample in self.p_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
):
final = sample
return final["sample"]
def p_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
):
"""
Generate samples from the model and yield intermediate samples from
each timestep of diffusion.
Arguments are the same as p_sample_loop().
Returns a generator over dicts, where each dict is the return value of
p_sample().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
img = noise if noise is not None else th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.p_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
)
yield out
img = out["sample"]
def ddim_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t-1} from the model using DDIM.
Same usage as p_sample().
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
if cond_fn is not None:
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
sigma = (
eta
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
)
# Equation 12.
noise = th.randn_like(x)
mean_pred = (
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
sample = mean_pred + nonzero_mask * sigma * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def ddim_reverse_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t+1} from the model using DDIM reverse ODE.
"""
assert eta == 0.0, "Reverse ODE only for deterministic path"
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
if cond_fn is not None:
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
- out["pred_xstart"]
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
# Equation 12. reversed
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
def ddim_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
):
"""
Generate samples from the model using DDIM.
Same usage as p_sample_loop().
"""
final = None
for sample in self.ddim_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
eta=eta,
):
final = sample
return final["sample"]
def ddim_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
):
"""
Use DDIM to sample from the model and yield intermediate samples from
each timestep of DDIM.
Same usage as p_sample_loop_progressive().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
img = noise if noise is not None else th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.ddim_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
eta=eta,
)
yield out
img = out["sample"]
def _vb_terms_bpd(
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
):
"""
Get a term for the variational lower-bound.
The resulting units are bits (rather than nats, as one might expect).
This allows for comparison to other papers.
:return: a dict with the following keys:
- 'output': a shape [N] tensor of NLLs or KLs.
- 'pred_xstart': the x_0 predictions.
"""
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)
out = self.p_mean_variance(
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
)
kl = normal_kl(
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
)
kl = mean_flat(kl) / np.log(2.0)
decoder_nll = -discretized_gaussian_log_likelihood(
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
)
assert decoder_nll.shape == x_start.shape
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
# At the first timestep return the decoder NLL,
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
output = th.where((t == 0), decoder_nll, kl)
return {"output": output, "pred_xstart": out["pred_xstart"]}
def training_losses(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False):
"""
Compute training losses for a single timestep.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param t: a batch of timestep indices.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param noise: if specified, the specific Gaussian noise to try to remove.
:return: a dict with the key "loss" containing a tensor of shape [N].
Some mean or variance settings may also have other keys.
"""
t = timestep
if model_kwargs is None:
model_kwargs = {}
if skip_noise:
x_t = x_start
else:
if noise is None:
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start, t, noise=noise)
terms = {}
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
terms["loss"] = self._vb_terms_bpd(
model=model,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
model_kwargs=model_kwargs,
)["output"]
if self.loss_type == LossType.RESCALED_KL:
terms["loss"] *= self.num_timesteps
elif self.loss_type in [LossType.MSE, LossType.RESCALED_MSE]:
model_output = model(x_t, t, **model_kwargs)
if isinstance(model_output, dict) and model_output.get('x', None) is not None:
output = model_output['x']
else:
output = model_output
if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON:
return self._extracted_from_training_losses_diffusers(x_t, output, t)
# self.model_var_type = ModelVarType.LEARNED_RANGE:4
if self.model_var_type in [
ModelVarType.LEARNED,
ModelVarType.LEARNED_RANGE,
]:
B, C = x_t.shape[:2]
assert output.shape == (B, C * 2, *x_t.shape[2:])
output, model_var_values = th.split(output, C, dim=1)
# Learn the variance using the variational bound, but don't let it affect our mean prediction.
frozen_out = th.cat([output.detach(), model_var_values], dim=1)
# vb variational bound
terms["vb"] = self._vb_terms_bpd(
model=lambda *args, r=frozen_out, **kwargs: r,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
)["output"]
if self.loss_type == LossType.RESCALED_MSE:
# Divide by 1000 for equivalence with initial implementation.
# Without a factor of 1/1000, the VB term hurts the MSE term.
terms["vb"] *= self.num_timesteps / 1000.0
target = {
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)[0],
ModelMeanType.START_X: x_start,
ModelMeanType.EPSILON: noise,
}[self.model_mean_type]
assert output.shape == target.shape == x_start.shape
if self.snr:
if self.model_mean_type == ModelMeanType.START_X:
pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output)
pred_startx = output
elif self.model_mean_type == ModelMeanType.EPSILON:
pred_noise = output
pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output)
# terms["mse_eps"] = mean_flat((noise - pred_noise) ** 2)
# terms["mse_x0"] = mean_flat((x_start - pred_startx) ** 2)
t = t[:, None, None, None].expand(pred_startx.shape) # [128, 4, 32, 32]
# best
target = th.where(t > 249, noise, x_start)
output = th.where(t > 249, pred_noise, pred_startx)
loss = (target - output) ** 2
if model_kwargs.get('mask_ratio', False) and model_kwargs['mask_ratio'] > 0:
assert 'mask' in model_output
loss = F.avg_pool2d(loss.mean(dim=1), model.model.module.patch_size).flatten(1)
mask = model_output['mask']
unmask = 1 - mask
terms['mse'] = mean_flat(loss * unmask) * unmask.shape[1]/unmask.sum(1)
if model_kwargs['mask_loss_coef'] > 0:
terms['mae'] = model_kwargs['mask_loss_coef'] * mean_flat(loss * mask) * mask.shape[1]/mask.sum(1)
else:
terms["mse"] = mean_flat(loss)
terms["loss"] = terms["mse"] + terms["vb"] if "vb" in terms else terms["mse"]
if "mae" in terms:
terms["loss"] = terms["loss"] + terms["mae"]
else:
raise NotImplementedError(self.loss_type)
return terms
def training_losses_diffusers(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False):
"""
Compute training losses for a single timestep.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param t: a batch of timestep indices.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param noise: if specified, the specific Gaussian noise to try to remove.
:return: a dict with the key "loss" containing a tensor of shape [N].
Some mean or variance settings may also have other keys.
"""
t = timestep
if model_kwargs is None:
model_kwargs = {}
if skip_noise:
x_t = x_start
else:
if noise is None:
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start, t, noise=noise)
terms = {}
if self.loss_type in [LossType.KL, LossType.RESCALED_KL]:
terms["loss"] = self._vb_terms_bpd(
model=model,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
model_kwargs=model_kwargs,
)["output"]
if self.loss_type == LossType.RESCALED_KL:
terms["loss"] *= self.num_timesteps
elif self.loss_type in [LossType.MSE, LossType.RESCALED_MSE]:
output = model(x_t, timestep=t, **model_kwargs, return_dict=False)[0]
if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON:
return self._extracted_from_training_losses_diffusers(x_t, output, t)
if self.model_var_type in [
ModelVarType.LEARNED,
ModelVarType.LEARNED_RANGE,
]:
B, C = x_t.shape[:2]
assert output.shape == (B, C * 2, *x_t.shape[2:])
output, model_var_values = th.split(output, C, dim=1)
# Learn the variance using the variational bound, but don't let it affect our mean prediction.
frozen_out = th.cat([output.detach(), model_var_values], dim=1)
terms["vb"] = self._vb_terms_bpd(
model=lambda *args, r=frozen_out, **kwargs: r,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
)["output"]
if self.loss_type == LossType.RESCALED_MSE:
# Divide by 1000 for equivalence with initial implementation.
# Without a factor of 1/1000, the VB term hurts the MSE term.
terms["vb"] *= self.num_timesteps / 1000.0
target = {
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)[0],
ModelMeanType.START_X: x_start,
ModelMeanType.EPSILON: noise,
}[self.model_mean_type]
assert output.shape == target.shape == x_start.shape
if self.snr:
if self.model_mean_type == ModelMeanType.START_X:
pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output)
pred_startx = output
elif self.model_mean_type == ModelMeanType.EPSILON:
pred_noise = output
pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output)
# terms["mse_eps"] = mean_flat((noise - pred_noise) ** 2)
# terms["mse_x0"] = mean_flat((x_start - pred_startx) ** 2)
t = t[:, None, None, None].expand(pred_startx.shape) # [128, 4, 32, 32]
# best
target = th.where(t > 249, noise, x_start)
output = th.where(t > 249, pred_noise, pred_startx)
loss = (target - output) ** 2
terms["mse"] = mean_flat(loss)
terms["loss"] = terms["mse"] + terms["vb"] if "vb" in terms else terms["mse"]
if "mae" in terms:
terms["loss"] = terms["loss"] + terms["mae"]
else:
raise NotImplementedError(self.loss_type)
return terms
def _extracted_from_training_losses_diffusers(self, x_t, output, t):
B, C = x_t.shape[:2]
assert output.shape == (B, C * 2, *x_t.shape[2:])
output = th.split(output, C, dim=1)[0]
return output, self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output), x_t
def _prior_bpd(self, x_start):
"""
Get the prior KL term for the variational lower-bound, measured in
bits-per-dim.
This term can't be optimized, as it only depends on the encoder.
:param x_start: the [N x C x ...] tensor of inputs.
:return: a batch of [N] KL values (in bits), one per batch element.
"""
batch_size = x_start.shape[0]
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
kl_prior = normal_kl(
mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
)
return mean_flat(kl_prior) / np.log(2.0)
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
"""
Compute the entire variational lower-bound, measured in bits-per-dim,
as well as other related quantities.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param clip_denoised: if True, clip denoised samples.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
- total_bpd: the total variational lower-bound, per batch element.
- prior_bpd: the prior term in the lower-bound.
- vb: an [N x T] tensor of terms in the lower-bound.
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
"""
device = x_start.device
batch_size = x_start.shape[0]
vb = []
xstart_mse = []
mse = []
for t in list(range(self.num_timesteps))[::-1]:
t_batch = th.tensor([t] * batch_size, device=device)
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
# Calculate VLB term at the current timestep
with th.no_grad():
out = self._vb_terms_bpd(
model,
x_start=x_start,
x_t=x_t,
t=t_batch,
clip_denoised=clip_denoised,
model_kwargs=model_kwargs,
)
vb.append(out["output"])
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
mse.append(mean_flat((eps - noise) ** 2))
vb = th.stack(vb, dim=1)
xstart_mse = th.stack(xstart_mse, dim=1)
mse = th.stack(mse, dim=1)
prior_bpd = self._prior_bpd(x_start)
total_bpd = vb.sum(dim=1) + prior_bpd
return {
"total_bpd": total_bpd,
"prior_bpd": prior_bpd,
"vb": vb,
"xstart_mse": xstart_mse,
"mse": mse,
}
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + th.zeros(broadcast_shape, device=timesteps.device)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment