Commit 78bae405 authored by mashun1's avatar mashun1
Browse files

open_sora_inference

parents
Pipeline #826 canceled with stages
from .stdit import STDiT
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp
from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
from opensora.models.layers.blocks import (
Attention,
CaptionEmbedder,
MultiHeadCrossAttention,
PatchEmbed3D,
SeqParallelAttention,
SeqParallelMultiHeadCrossAttention,
T2IFinalLayer,
TimestepEmbedder,
approx_gelu,
get_1d_sincos_pos_embed,
get_2d_sincos_pos_embed,
get_layernorm,
t2i_modulate,
)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint
class STDiTBlock(nn.Module):
def __init__(
self,
hidden_size,
num_heads,
d_s=None,
d_t=None,
mlp_ratio=4.0,
drop_path=0.0,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
):
super().__init__()
self.hidden_size = hidden_size
self.enable_flashattn = enable_flashattn
self._enable_sequence_parallelism = enable_sequence_parallelism
if enable_sequence_parallelism:
self.attn_cls = SeqParallelAttention
self.mha_cls = SeqParallelMultiHeadCrossAttention
else:
self.attn_cls = Attention
self.mha_cls = MultiHeadCrossAttention
self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.attn = self.attn_cls(
hidden_size,
num_heads=num_heads,
qkv_bias=True,
enable_flashattn=enable_flashattn,
)
self.cross_attn = self.mha_cls(hidden_size, num_heads)
self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
# temporal attention
self.d_s = d_s
self.d_t = d_t
if self._enable_sequence_parallelism:
sp_size = dist.get_world_size(get_sequence_parallel_group())
# make sure d_t is divisible by sp_size
assert d_t % sp_size == 0
self.d_t = d_t // sp_size
self.attn_temp = self.attn_cls(
hidden_size,
num_heads=num_heads,
qkv_bias=True,
enable_flashattn=self.enable_flashattn,
)
def forward(self, x, y, t, mask=None, tpe=None):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + t.reshape(B, 6, -1)
).chunk(6, dim=1)
x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
# spatial branch
x_s = rearrange(x_m, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
x_s = self.attn(x_s)
x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=self.d_t, S=self.d_s)
x = x + self.drop_path(gate_msa * x_s)
# temporal branch
x_t = rearrange(x, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s)
if tpe is not None:
x_t = x_t + tpe
x_t = self.attn_temp(x_t)
x_t = rearrange(x_t, "(B S) T C -> B (T S) C", T=self.d_t, S=self.d_s)
x = x + self.drop_path(gate_msa * x_t)
# cross attn
x = x + self.cross_attn(x, y, mask)
# mlp
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
@MODELS.register_module()
class STDiT(nn.Module):
def __init__(
self,
input_size=(1, 32, 32),
in_channels=4,
patch_size=(1, 2, 2),
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
pred_sigma=True,
drop_path=0.0,
no_temporal_pos_emb=False,
caption_channels=4096,
model_max_length=120,
dtype=torch.float32,
space_scale=1.0,
time_scale=1.0,
freeze=None,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
):
super().__init__()
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.hidden_size = hidden_size
self.patch_size = patch_size
self.input_size = input_size
num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
self.num_patches = num_patches
self.num_temporal = input_size[0] // patch_size[0]
self.num_spatial = num_patches // self.num_temporal
self.num_heads = num_heads
self.dtype = dtype
self.no_temporal_pos_emb = no_temporal_pos_emb
self.depth = depth
self.mlp_ratio = mlp_ratio
self.enable_flashattn = enable_flashattn
self.enable_layernorm_kernel = enable_layernorm_kernel
self.space_scale = space_scale
self.time_scale = time_scale
self.register_buffer("pos_embed", self.get_spatial_pos_embed())
self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels,
hidden_size=hidden_size,
uncond_prob=class_dropout_prob,
act_layer=approx_gelu,
token_num=model_max_length,
)
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]
self.blocks = nn.ModuleList(
[
STDiTBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=self.mlp_ratio,
drop_path=drop_path[i],
enable_flashattn=self.enable_flashattn,
enable_layernorm_kernel=self.enable_layernorm_kernel,
enable_sequence_parallelism=enable_sequence_parallelism,
d_t=self.num_temporal,
d_s=self.num_spatial,
)
for i in range(self.depth)
]
)
self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)
# init model
self.initialize_weights()
self.initialize_temporal()
if freeze is not None:
assert freeze in ["not_temporal", "text"]
if freeze == "not_temporal":
self.freeze_not_temporal()
elif freeze == "text":
self.freeze_text()
# sequence parallel related configs
self.enable_sequence_parallelism = enable_sequence_parallelism
if enable_sequence_parallelism:
self.sp_rank = dist.get_rank(get_sequence_parallel_group())
else:
self.sp_rank = None
def forward(self, x, timestep, y, mask=None):
"""
Forward pass of STDiT.
Args:
x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W]
timestep (torch.Tensor): diffusion time steps; of shape [B]
y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C]
mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token]
Returns:
x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
# embedding
x = self.x_embedder(x) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")
# shard over the sequence dim if sp is enabled
if self.enable_sequence_parallelism:
x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
t0 = self.t_block(t) # [B, C]
y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
# blocks
for i, block in enumerate(self.blocks):
if i == 0:
if self.enable_sequence_parallelism:
tpe = torch.chunk(
self.pos_embed_temporal, dist.get_world_size(get_sequence_parallel_group()), dim=1
)[self.sp_rank].contiguous()
else:
tpe = self.pos_embed_temporal
else:
tpe = None
x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)
if self.enable_sequence_parallelism:
x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
# x.shape: [B, N, C]
# final process
x = self.final_layer(x, t) # [B, N, C=T_p * H_p * W_p * C_out]
x = self.unpatchify(x) # [B, C_out, T, H, W]
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x
def unpatchify(self, x):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
return x
def unpatchify_old(self, x):
c = self.out_channels
t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
pt, ph, pw = self.patch_size
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def get_spatial_pos_embed(self, grid_size=None):
if grid_size is None:
grid_size = self.input_size[1:]
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size,
(grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
scale=self.space_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def get_temporal_pos_embed(self):
pos_embed = get_1d_sincos_pos_embed(
self.hidden_size,
self.input_size[0] // self.patch_size[0],
scale=self.time_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
def freeze_not_temporal(self):
for n, p in self.named_parameters():
if "attn_temp" not in n:
p.requires_grad = False
def freeze_text(self):
for n, p in self.named_parameters():
if "cross_attn" in n:
p.requires_grad = False
def initialize_temporal(self):
for block in self.blocks:
nn.init.constant_(block.attn_temp.proj.weight, 0)
nn.init.constant_(block.attn_temp.proj.bias, 0)
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@MODELS.register_module("STDiT-XL/2")
def STDiT_XL_2(from_pretrained=None, **kwargs):
model = STDiT(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
from .classes import ClassEncoder
from .clip import ClipEncoder
from .t5 import T5Encoder
import torch
from opensora.registry import MODELS
@MODELS.register_module("classes")
class ClassEncoder:
def __init__(self, num_classes, model_max_length=None, device="cuda", dtype=torch.float):
self.num_classes = num_classes
self.y_embedder = None
self.model_max_length = model_max_length
self.output_dim = None
self.device = device
def encode(self, text):
return dict(y=torch.tensor([int(t) for t in text]).to(self.device))
def null(self, n):
return torch.tensor([self.num_classes] * n).to(self.device)
# Copyright 2024 Vchitect/Latte
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.# Modified from Latte
#
# This file is adapted from the Latte project.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Latte: https://github.com/Vchitect/Latte
# DiT: https://github.com/facebookresearch/DiT/tree/main
# --------------------------------------------------------
import torch
import torch.nn as nn
import transformers
from transformers import CLIPTextModel, CLIPTokenizer
from opensora.registry import MODELS
transformers.logging.set_verbosity_error()
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, path="openai/clip-vit-huge-patch14", device="cuda", max_length=77):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(path)
self.transformer = CLIPTextModel.from_pretrained(path)
self.device = device
self.max_length = max_length
self._freeze()
def _freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
pooled_z = outputs.pooler_output
return z, pooled_z
def encode(self, text):
return self(text)
@MODELS.register_module("clip")
class ClipEncoder:
"""
Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.
"""
def __init__(
self,
from_pretrained,
model_max_length=77,
device="cuda",
dtype=torch.float,
):
super().__init__()
assert from_pretrained is not None, "Please specify the path to the T5 model"
self.text_encoder = FrozenCLIPEmbedder(path=from_pretrained, max_length=model_max_length).to(device, dtype)
self.y_embedder = None
self.model_max_length = model_max_length
self.output_dim = self.text_encoder.transformer.config.hidden_size
def encode(self, text):
_, pooled_embeddings = self.text_encoder.encode(text)
y = pooled_embeddings.unsqueeze(1).unsqueeze(1)
return dict(y=y)
def null(self, n):
null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
return null_y
def to(self, dtype):
self.text_encoder = self.text_encoder.to(dtype)
return self
# Adapted from PixArt
#
# Copyright (C) 2023 PixArt-alpha/PixArt-alpha
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha
# T5: https://github.com/google-research/text-to-text-transfer-transformer
# --------------------------------------------------------
import html
import os
import re
import urllib.parse as ul
import ftfy
import torch
from bs4 import BeautifulSoup
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, T5EncoderModel
from opensora.registry import MODELS
class T5Embedder:
available_models = ["t5-v1_1-xxl"]
bad_punct_regex = re.compile(
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
def __init__(
self,
device,
dir_or_name="t5-v1_1-xxl",
*,
local_cache=False,
cache_dir=None,
hf_token=None,
use_text_preprocessing=True,
t5_model_kwargs=None,
torch_dtype=None,
use_offload_folder=None,
model_max_length=120,
):
self.device = torch.device(device)
self.torch_dtype = torch_dtype or torch.bfloat16
if t5_model_kwargs is None:
t5_model_kwargs = {"low_cpu_mem_usage": True, "torch_dtype": self.torch_dtype}
if use_offload_folder is not None:
t5_model_kwargs["offload_folder"] = use_offload_folder
t5_model_kwargs["device_map"] = {
"shared": self.device,
"encoder.embed_tokens": self.device,
"encoder.block.0": self.device,
"encoder.block.1": self.device,
"encoder.block.2": self.device,
"encoder.block.3": self.device,
"encoder.block.4": self.device,
"encoder.block.5": self.device,
"encoder.block.6": self.device,
"encoder.block.7": self.device,
"encoder.block.8": self.device,
"encoder.block.9": self.device,
"encoder.block.10": self.device,
"encoder.block.11": self.device,
"encoder.block.12": "disk",
"encoder.block.13": "disk",
"encoder.block.14": "disk",
"encoder.block.15": "disk",
"encoder.block.16": "disk",
"encoder.block.17": "disk",
"encoder.block.18": "disk",
"encoder.block.19": "disk",
"encoder.block.20": "disk",
"encoder.block.21": "disk",
"encoder.block.22": "disk",
"encoder.block.23": "disk",
"encoder.final_layer_norm": "disk",
"encoder.dropout": "disk",
}
else:
t5_model_kwargs["device_map"] = {"shared": self.device, "encoder": self.device}
self.use_text_preprocessing = use_text_preprocessing
self.hf_token = hf_token
self.cache_dir = cache_dir or os.path.expanduser("~/.cache/IF_")
self.dir_or_name = dir_or_name
tokenizer_path, path = dir_or_name, dir_or_name
if local_cache:
cache_dir = os.path.join(self.cache_dir, dir_or_name)
tokenizer_path, path = cache_dir, cache_dir
elif dir_or_name in self.available_models:
cache_dir = os.path.join(self.cache_dir, dir_or_name)
for filename in [
"config.json",
"special_tokens_map.json",
"spiece.model",
"tokenizer_config.json",
"pytorch_model.bin.index.json",
"pytorch_model-00001-of-00002.bin",
"pytorch_model-00002-of-00002.bin",
]:
hf_hub_download(
repo_id=f"DeepFloyd/{dir_or_name}",
filename=filename,
cache_dir=cache_dir,
force_filename=filename,
token=self.hf_token,
)
tokenizer_path, path = cache_dir, cache_dir
else:
cache_dir = os.path.join(self.cache_dir, "t5-v1_1-xxl")
for filename in [
"config.json",
"special_tokens_map.json",
"spiece.model",
"tokenizer_config.json",
]:
hf_hub_download(
repo_id="DeepFloyd/t5-v1_1-xxl",
filename=filename,
cache_dir=cache_dir,
force_filename=filename,
token=self.hf_token,
)
tokenizer_path = cache_dir
print(tokenizer_path)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
self.model_max_length = model_max_length
def get_text_embeddings(self, texts):
texts = [self.text_preprocessing(text) for text in texts]
text_tokens_and_mask = self.tokenizer(
texts,
max_length=self.model_max_length,
padding="max_length",
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
text_tokens_and_mask["input_ids"] = text_tokens_and_mask["input_ids"]
text_tokens_and_mask["attention_mask"] = text_tokens_and_mask["attention_mask"]
with torch.no_grad():
text_encoder_embs = self.model(
input_ids=text_tokens_and_mask["input_ids"].to(self.device),
attention_mask=text_tokens_and_mask["attention_mask"].to(self.device),
)["last_hidden_state"].detach()
return text_encoder_embs, text_tokens_and_mask["attention_mask"].to(self.device)
def text_preprocessing(self, text):
if self.use_text_preprocessing:
# The exact text cleaning as was in the training stage:
text = self.clean_caption(text)
text = self.clean_caption(text)
return text
else:
return text.lower().strip()
@staticmethod
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def clean_caption(self, caption):
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
caption = re.sub("<person>", "person", caption)
# urls:
caption = re.sub(
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
caption = re.sub(
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
# html:
caption = BeautifulSoup(caption, features="html.parser").text
# @<nickname>
caption = re.sub(r"@[\w\d]+\b", "", caption)
# 31C0—31EF CJK Strokes
# 31F0—31FF Katakana Phonetic Extensions
# 3200—32FF Enclosed CJK Letters and Months
# 3300—33FF CJK Compatibility
# 3400—4DBF CJK Unified Ideographs Extension A
# 4DC0—4DFF Yijing Hexagram Symbols
# 4E00—9FFF CJK Unified Ideographs
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
#######################################################
# все виды тире / all types of dash --> "-"
caption = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
"-",
caption,
)
# кавычки к одному стандарту
caption = re.sub(r"[`´«»“”¨]", '"', caption)
caption = re.sub(r"[‘’]", "'", caption)
# &quot;
caption = re.sub(r"&quot;?", "", caption)
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
caption = re.sub(r"\d:\d\d\s+$", "", caption)
# \n
caption = re.sub(r"\\n", " ", caption)
# "#123"
caption = re.sub(r"#\d{1,3}\b", "", caption)
# "#12345.."
caption = re.sub(r"#\d{5,}\b", "", caption)
# "123456.."
caption = re.sub(r"\b\d{6,}\b", "", caption)
# filenames:
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
#
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, caption)) > 3:
caption = re.sub(regex2, " ", caption)
caption = self.basic_clean(caption)
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
caption = re.sub(r"\s+", " ", caption)
caption.strip()
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
caption = re.sub(r"^\.\S+$", "", caption)
return caption.strip()
@MODELS.register_module("t5")
class T5Encoder:
def __init__(
self,
from_pretrained=None,
model_max_length=120,
device="cuda",
dtype=torch.float,
local_cache=True,
shardformer=False,
):
assert from_pretrained is not None, "Please specify the path to the T5 model"
self.t5 = T5Embedder(
device=device,
torch_dtype=dtype,
local_cache=local_cache,
cache_dir=from_pretrained,
model_max_length=model_max_length,
)
self.t5.model.to(dtype=dtype)
self.y_embedder = None
self.model_max_length = model_max_length
self.output_dim = self.t5.model.config.d_model
if shardformer:
self.shardformer_t5()
def shardformer_t5(self):
from colossalai.shardformer import ShardConfig, ShardFormer
from opensora.acceleration.shardformer.policy.t5_encoder import T5EncoderPolicy
from opensora.utils.misc import requires_grad
shard_config = ShardConfig(
tensor_parallel_process_group=None,
pipeline_stage_manager=None,
enable_tensor_parallelism=False,
enable_fused_normalization=False,
enable_flash_attention=False,
enable_jit_fused=True,
enable_sequence_parallelism=False,
enable_sequence_overlap=False,
)
shard_former = ShardFormer(shard_config=shard_config)
optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy())
self.t5.model = optim_model.half()
# ensure the weights are frozen
requires_grad(self.t5.model, False)
def encode(self, text):
caption_embs, emb_masks = self.t5.get_text_embeddings(text)
caption_embs = caption_embs[:, None]
return dict(y=caption_embs, mask=emb_masks)
def null(self, n):
null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
return null_y
from .vae import VideoAutoencoderKL, VideoAutoencoderKLTemporalDecoder
import torch
import torch.nn as nn
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from einops import rearrange
from opensora.registry import MODELS
@MODELS.register_module()
class VideoAutoencoderKL(nn.Module):
def __init__(self, from_pretrained=None, micro_batch_size=None):
super().__init__()
self.module = AutoencoderKL.from_pretrained(from_pretrained)
self.out_channels = self.module.config.latent_channels
self.patch_size = (1, 8, 8)
self.micro_batch_size = micro_batch_size
def encode(self, x):
# x: (B, C, T, H, W)
B = x.shape[0]
x = rearrange(x, "B C T H W -> (B T) C H W")
if self.micro_batch_size is None:
x = self.module.encode(x).latent_dist.sample().mul_(0.18215)
else:
bs = self.micro_batch_size
x_out = []
for i in range(0, x.shape[0], bs):
x_bs = x[i : i + bs]
x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(0.18215)
x_out.append(x_bs)
x = torch.cat(x_out, dim=0)
x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
return x
def decode(self, x):
# x: (B, C, T, H, W)
B = x.shape[0]
x = rearrange(x, "B C T H W -> (B T) C H W")
if self.micro_batch_size is None:
x = self.module.decode(x / 0.18215).sample
else:
bs = self.micro_batch_size
x_out = []
for i in range(0, x.shape[0], bs):
x_bs = x[i : i + bs]
x_bs = self.module.decode(x_bs / 0.18215).sample
x_out.append(x_bs)
x = torch.cat(x_out, dim=0)
x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
return x
def get_latent_size(self, input_size):
for i in range(3):
assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
return input_size
@MODELS.register_module()
class VideoAutoencoderKLTemporalDecoder(nn.Module):
def __init__(self, from_pretrained=None):
super().__init__()
self.module = AutoencoderKLTemporalDecoder.from_pretrained(from_pretrained)
self.out_channels = self.module.config.latent_channels
self.patch_size = (1, 8, 8)
def encode(self, x):
raise NotImplementedError
def decode(self, x):
B, _, T = x.shape[:3]
x = rearrange(x, "B C T H W -> (B T) C H W")
x = self.module.decode(x / 0.18215, num_frames=T).sample
x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
return x
def get_latent_size(self, input_size):
for i in range(3):
assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
return input_size
from copy import deepcopy
import torch.nn as nn
from mmengine.registry import Registry
def build_module(module, builder, **kwargs):
"""Build module from config or return the module itself.
Args:
module (Union[dict, nn.Module]): The module to build.
builder (Registry): The registry to build module.
*args, **kwargs: Arguments passed to build function.
Returns:
Any: The built module.
"""
if isinstance(module, dict):
cfg = deepcopy(module)
for k, v in kwargs.items():
cfg[k] = v
return builder.build(cfg)
elif isinstance(module, nn.Module):
return module
elif module is None:
return None
else:
raise TypeError(f"Only support dict and nn.Module, but got {type(module)}.")
MODELS = Registry(
"model",
locations=["opensora.models"],
)
SCHEDULERS = Registry(
"scheduler",
locations=["opensora.schedulers"],
)
from .dpms import DPMS
from .iddpm import IDDPM
from functools import partial
import torch
from opensora.registry import SCHEDULERS
from .dpm_solver import DPMS
@SCHEDULERS.register_module("dpm-solver")
class DMP_SOLVER:
def __init__(self, num_sampling_steps=None, cfg_scale=4.0):
self.num_sampling_steps = num_sampling_steps
self.cfg_scale = cfg_scale
def sample(
self,
model,
text_encoder,
z_size,
prompts,
device,
additional_args=None,
):
n = len(prompts)
z = torch.randn(n, *z_size, device=device)
model_args = text_encoder.encode(prompts)
y = model_args.pop("y")
null_y = text_encoder.null(n)
if additional_args is not None:
model_args.update(additional_args)
dpms = DPMS(
partial(forward_with_dpmsolver, model),
condition=y,
uncondition=null_y,
cfg_scale=self.cfg_scale,
model_kwargs=model_args,
)
samples = dpms.sample(z, steps=self.num_sampling_steps, order=2, skip_type="time_uniform", method="multistep")
return samples
def forward_with_dpmsolver(self, x, timestep, y, **kwargs):
"""
dpm solver donnot need variance prediction
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, **kwargs)
return model_out.chunk(2, dim=1)[0]
# MIT License
#
# Copyright (c) 2022 Cheng Lu
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
#
# This file is adapted from the dpm-solver project
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha
# dpm-solver: https://github.com/LuChengTHU/dpm-solver
# --------------------------------------------------------
import math
import numpy as np
import torch
from tqdm import tqdm
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
warmup_time = int(num_diffusion_timesteps * warmup_frac)
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
return betas
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
"""
This is the deprecated API for creating beta schedules.
See get_named_beta_schedule() for the new library of schedules.
"""
if beta_schedule == "quad":
betas = (
np.linspace(
beta_start**0.5,
beta_end**0.5,
num_diffusion_timesteps,
dtype=np.float64,
)
** 2
)
elif beta_schedule == "linear":
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
elif beta_schedule == "warmup10":
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
elif beta_schedule == "warmup50":
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
elif beta_schedule == "const":
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
else:
raise NotImplementedError(beta_schedule)
assert betas.shape == (num_diffusion_timesteps,)
return betas
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
"""
Get a pre-defined beta schedule for the given name.
The beta schedule library consists of beta schedules which remain similar
in the limit of num_diffusion_timesteps.
Beta schedules may be added, but should not be removed or changed once
they are committed to maintain backwards compatibility.
"""
if schedule_name == "linear":
# Linear schedule from Ho et al, extended to work for any number of
# diffusion steps.
scale = 1000 / num_diffusion_timesteps
return get_beta_schedule(
"linear",
beta_start=scale * 0.0001,
beta_end=scale * 0.02,
num_diffusion_timesteps=num_diffusion_timesteps,
)
elif schedule_name == "squaredcos_cap_v2":
return betas_for_alpha_bar(
num_diffusion_timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
class NoiseScheduleVP:
def __init__(
self,
schedule="discrete",
betas=None,
alphas_cumprod=None,
continuous_beta_0=0.1,
continuous_beta_1=20.0,
dtype=torch.float32,
):
"""Create a wrapper class for the forward SDE (VP type).
***
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
***
The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
log_alpha_t = self.marginal_log_mean_coeff(t)
sigma_t = self.marginal_std(t)
lambda_t = self.marginal_lambda(t)
Moreover, as lambda(t) is an invertible function, we also support its inverse function:
t = self.inverse_lambda(lambda_t)
===============================================================
We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
1. For discrete-time DPMs:
For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
t_i = (i + 1) / N
e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
Args:
betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
**Important**: Please pay special attention for the args for `alphas_cumprod`:
The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
alpha_{t_n} = \sqrt{\hat{alpha_n}},
and
log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
2. For continuous-time DPMs:
We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise
schedule are the default settings in Yang Song's ScoreSDE:
Args:
beta_min: A `float` number. The smallest beta for the linear schedule.
beta_max: A `float` number. The largest beta for the linear schedule.
T: A `float` number. The ending time of the forward process.
===============================================================
Args:
schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
'linear' for continuous-time DPMs.
Returns:
A wrapper object of the forward SDE (VP type).
===============================================================
Example:
# For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
>>> ns = NoiseScheduleVP('discrete', betas=betas)
# For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
>>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
# For continuous-time DPMs (VPSDE), linear schedule:
>>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
"""
if schedule not in ["discrete", "linear"]:
raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear'")
self.schedule = schedule
if schedule == "discrete":
if betas is not None:
log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
else:
assert alphas_cumprod is not None
log_alphas = 0.5 * torch.log(alphas_cumprod)
self.T = 1.0
self.log_alpha_array = (
self.numerical_clip_alpha(log_alphas)
.reshape(
(
1,
-1,
)
)
.to(dtype=dtype)
)
self.total_N = self.log_alpha_array.shape[1]
self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
else:
self.T = 1.0
self.total_N = 1000
self.beta_0 = continuous_beta_0
self.beta_1 = continuous_beta_1
def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1):
"""
For some beta schedules such as cosine schedule, the log-SNR has numerical isssues.
We clip the log-SNR near t=T within -5.1 to ensure the stability.
Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE.
"""
log_sigmas = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_alphas))
lambs = log_alphas - log_sigmas
idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda)
if idx > 0:
log_alphas = log_alphas[:-idx]
return log_alphas
def marginal_log_mean_coeff(self, t):
"""
Compute log(alpha_t) of a given continuous-time label t in [0, T].
"""
if self.schedule == "discrete":
return interpolate_fn(
t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)
).reshape((-1))
elif self.schedule == "linear":
return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
def marginal_alpha(self, t):
"""
Compute alpha_t of a given continuous-time label t in [0, T].
"""
return torch.exp(self.marginal_log_mean_coeff(t))
def marginal_std(self, t):
"""
Compute sigma_t of a given continuous-time label t in [0, T].
"""
return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))
def marginal_lambda(self, t):
"""
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
"""
log_mean_coeff = self.marginal_log_mean_coeff(t)
log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
return log_mean_coeff - log_std
def inverse_lambda(self, lamb):
"""
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
"""
if self.schedule == "linear":
tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
Delta = self.beta_0**2 + tmp
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
elif self.schedule == "discrete":
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb)
t = interpolate_fn(
log_alpha.reshape((-1, 1)),
torch.flip(self.log_alpha_array.to(lamb.device), [1]),
torch.flip(self.t_array.to(lamb.device), [1]),
)
return t.reshape((-1,))
def model_wrapper(
model,
noise_schedule,
model_type="noise",
model_kwargs={},
guidance_type="uncond",
condition=None,
unconditional_condition=None,
guidance_scale=1.0,
classifier_fn=None,
classifier_kwargs={},
):
"""Create a wrapper function for the noise prediction model.
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
We support four types of the diffusion model by setting `model_type`:
1. "noise": noise prediction model. (Trained by predicting noise).
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
3. "v": velocity prediction model. (Trained by predicting the velocity).
The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
arXiv preprint arXiv:2202.00512 (2022).
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
arXiv preprint arXiv:2210.02303 (2022).
4. "score": marginal score function. (Trained by denoising score matching).
Note that the score function and the noise prediction model follows a simple relationship:
```
noise(x_t, t) = -sigma_t * score(x_t, t)
```
We support three types of guided sampling by DPMs by setting `guidance_type`:
1. "uncond": unconditional sampling by DPMs.
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
The input `classifier_fn` has the following format:
``
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
``
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
The input `model` has the following format:
``
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
``
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
arXiv preprint arXiv:2207.12598 (2022).
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
or continuous-time labels (i.e. epsilon to T).
We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
``
def model_fn(x, t_continuous) -> noise:
t_input = get_model_input_time(t_continuous)
return noise_pred(model, x, t_input, **model_kwargs)
``
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
===============================================================
Args:
model: A diffusion model with the corresponding format described above.
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
model_type: A `str`. The parameterization type of the diffusion model.
"noise" or "x_start" or "v" or "score".
model_kwargs: A `dict`. A dict for the other inputs of the model function.
guidance_type: A `str`. The type of the guidance for sampling.
"uncond" or "classifier" or "classifier-free".
condition: A pytorch tensor. The condition for the guided sampling.
Only used for "classifier" or "classifier-free" guidance type.
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
Only used for "classifier-free" guidance type.
guidance_scale: A `float`. The scale for the guided sampling.
classifier_fn: A classifier function. Only used for the classifier guidance.
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
Returns:
A noise prediction model that accepts the noised data and the continuous time as the inputs.
"""
def get_model_input_time(t_continuous):
"""
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
For continuous-time DPMs, we just use `t_continuous`.
"""
if noise_schedule.schedule == "discrete":
return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
else:
return t_continuous
def noise_pred_fn(x, t_continuous, cond=None):
t_input = get_model_input_time(t_continuous)
if cond is None:
output = model(x, t_input, **model_kwargs)
else:
output = model(x, t_input, cond, **model_kwargs)
if model_type == "noise":
return output
elif model_type == "x_start":
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim())
elif model_type == "v":
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x
elif model_type == "score":
sigma_t = noise_schedule.marginal_std(t_continuous)
return -expand_dims(sigma_t, x.dim()) * output
def cond_grad_fn(x, t_input):
"""
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
"""
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
return torch.autograd.grad(log_prob.sum(), x_in)[0]
def model_fn(x, t_continuous):
"""
The noise predicition model function that is used for DPM-Solver.
"""
if guidance_type == "uncond":
return noise_pred_fn(x, t_continuous)
elif guidance_type == "classifier":
assert classifier_fn is not None
t_input = get_model_input_time(t_continuous)
cond_grad = cond_grad_fn(x, t_input)
sigma_t = noise_schedule.marginal_std(t_continuous)
noise = noise_pred_fn(x, t_continuous)
return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad
elif guidance_type == "classifier-free":
if guidance_scale == 1.0 or unconditional_condition is None:
return noise_pred_fn(x, t_continuous, cond=condition)
x_in = torch.cat([x] * 2)
t_in = torch.cat([t_continuous] * 2)
c_in = torch.cat([unconditional_condition, condition])
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
return noise_uncond + guidance_scale * (noise - noise_uncond)
assert model_type in ["noise", "x_start", "v", "score"]
assert guidance_type in ["uncond", "classifier", "classifier-free"]
return model_fn
class DPM_Solver:
def __init__(
self,
model_fn,
noise_schedule,
algorithm_type="dpmsolver++",
correcting_x0_fn=None,
correcting_xt_fn=None,
thresholding_max_val=1.0,
dynamic_thresholding_ratio=0.995,
):
"""Construct a DPM-Solver.
We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).
We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
DPMs (such as stable-diffusion).
To support advanced algorithms in image-to-image applications, we also support corrector functions for
both x0 and xt.
Args:
model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
``
def model_fn(x, t_continuous):
return noise
``
The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
correcting_x0_fn: A `str` or a function with the following format:
```
def correcting_x0_fn(x0, t):
x0_new = ...
return x0_new
```
This function is to correct the outputs of the data prediction model at each sampling step. e.g.,
```
x0_pred = data_pred_model(xt, t)
if correcting_x0_fn is not None:
x0_pred = correcting_x0_fn(x0_pred, t)
xt_1 = update(x0_pred, xt, t)
```
If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
correcting_xt_fn: A function with the following format:
```
def correcting_xt_fn(xt, t, step):
x_new = ...
return x_new
```
This function is to correct the intermediate samples xt at each sampling step. e.g.,
```
xt = ...
xt = correcting_xt_fn(xt, t, step)
```
thresholding_max_val: A `float`. The max value for thresholding.
Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
[1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
"""
self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
self.noise_schedule = noise_schedule
assert algorithm_type in ["dpmsolver", "dpmsolver++"]
self.algorithm_type = algorithm_type
if correcting_x0_fn == "dynamic_thresholding":
self.correcting_x0_fn = self.dynamic_thresholding_fn
else:
self.correcting_x0_fn = correcting_x0_fn
self.correcting_xt_fn = correcting_xt_fn
self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
self.thresholding_max_val = thresholding_max_val
def dynamic_thresholding_fn(self, x0, t):
"""
The dynamic thresholding method.
"""
dims = x0.dim()
p = self.dynamic_thresholding_ratio
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s
return x0
def noise_prediction_fn(self, x, t):
"""
Return the noise prediction model.
"""
return self.model(x, t)
def data_prediction_fn(self, x, t):
"""
Return the data prediction model (with corrector).
"""
noise = self.noise_prediction_fn(x, t)
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
x0 = (x - sigma_t * noise) / alpha_t
if self.correcting_x0_fn is not None:
x0 = self.correcting_x0_fn(x0, t)
return x0
def model_fn(self, x, t):
"""
Convert the model to the noise prediction model or the data prediction model.
"""
if self.algorithm_type == "dpmsolver++":
return self.data_prediction_fn(x, t)
else:
return self.noise_prediction_fn(x, t)
def get_time_steps(self, skip_type, t_T, t_0, N, device):
"""Compute the intermediate time steps for sampling.
Args:
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
- 'logSNR': uniform logSNR for the time steps.
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
N: A `int`. The total number of the spacing of the time steps.
device: A torch device.
Returns:
A pytorch tensor of the time steps, with the shape (N + 1,).
"""
if skip_type == "logSNR":
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
return self.noise_schedule.inverse_lambda(logSNR_steps)
elif skip_type == "time_uniform":
return torch.linspace(t_T, t_0, N + 1).to(device)
elif skip_type == "time_quadratic":
t_order = 2
return torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
else:
raise ValueError(
f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'"
)
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
"""
Get the order of each step for sampling by the singlestep DPM-Solver.
We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
- If order == 1:
We take `steps` of DPM-Solver-1 (i.e. DDIM).
- If order == 2:
- Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
- If steps % 2 == 0, we use K steps of DPM-Solver-2.
- If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
- If order == 3:
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
- If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
- If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
- If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
============================================
Args:
order: A `int`. The max order for the solver (2 or 3).
steps: A `int`. The total number of function evaluations (NFE).
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
- 'logSNR': uniform logSNR for the time steps.
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
device: A torch device.
Returns:
orders: A list of the solver order of each step.
"""
if order == 3:
K = steps // 3 + 1
if steps % 3 == 0:
orders = [
3,
] * (
K - 2
) + [2, 1]
elif steps % 3 == 1:
orders = [
3,
] * (
K - 1
) + [1]
else:
orders = [
3,
] * (
K - 1
) + [2]
elif order == 2:
if steps % 2 == 0:
K = steps // 2
orders = [
2,
] * K
else:
K = steps // 2 + 1
orders = [
2,
] * (
K - 1
) + [1]
elif order == 1:
K = 1
orders = [
1,
] * steps
else:
raise ValueError("'order' must be '1' or '2' or '3'.")
if skip_type == "logSNR":
# To reproduce the results in DPM-Solver paper
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
else:
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
torch.cumsum(
torch.tensor(
[
0,
]
+ orders
),
0,
).to(device)
]
return timesteps_outer, orders
def denoise_to_zero_fn(self, x, s):
"""
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
"""
return self.data_prediction_fn(x, s)
def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
"""
DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (1,).
t: A pytorch tensor. The ending time, with the shape (1,).
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s`.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
ns = self.noise_schedule
x.dim()
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
if self.algorithm_type == "dpmsolver++":
phi_1 = torch.expm1(-h)
if model_s is None:
model_s = self.model_fn(x, s)
x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s
else:
phi_1 = torch.expm1(h)
if model_s is None:
model_s = self.model_fn(x, s)
x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s
return (x_t, {"model_s": model_s}) if return_intermediate else x_t
def singlestep_dpm_solver_second_update(
self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type="dpmsolver"
):
"""
Singlestep solver DPM-Solver-2 from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (1,).
t: A pytorch tensor. The ending time, with the shape (1,).
r1: A `float`. The hyperparameter of the second-order solver.
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ["dpmsolver", "taylor"]:
raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
if r1 is None:
r1 = 0.5
ns = self.noise_schedule
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
lambda_s1 = lambda_s + r1 * h
s1 = ns.inverse_lambda(lambda_s1)
log_alpha_s, log_alpha_s1, log_alpha_t = (
ns.marginal_log_mean_coeff(s),
ns.marginal_log_mean_coeff(s1),
ns.marginal_log_mean_coeff(t),
)
sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
if self.algorithm_type == "dpmsolver++":
phi_11 = torch.expm1(-r1 * h)
phi_1 = torch.expm1(-h)
if model_s is None:
model_s = self.model_fn(x, s)
x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
model_s1 = self.model_fn(x_s1, s1)
if solver_type == "dpmsolver":
x_t = (
(sigma_t / sigma_s) * x
- (alpha_t * phi_1) * model_s
- (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
)
elif solver_type == "taylor":
x_t = (
(sigma_t / sigma_s) * x
- (alpha_t * phi_1) * model_s
+ (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s)
)
else:
phi_11 = torch.expm1(r1 * h)
phi_1 = torch.expm1(h)
if model_s is None:
model_s = self.model_fn(x, s)
x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s
model_s1 = self.model_fn(x_s1, s1)
if solver_type == "dpmsolver":
x_t = (
torch.exp(log_alpha_t - log_alpha_s) * x
- (sigma_t * phi_1) * model_s
- (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
)
elif solver_type == "taylor":
x_t = (
torch.exp(log_alpha_t - log_alpha_s) * x
- (sigma_t * phi_1) * model_s
- (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s)
)
if return_intermediate:
return x_t, {"model_s": model_s, "model_s1": model_s1}
else:
return x_t
def singlestep_dpm_solver_third_update(
self,
x,
s,
t,
r1=1.0 / 3.0,
r2=2.0 / 3.0,
model_s=None,
model_s1=None,
return_intermediate=False,
solver_type="dpmsolver",
):
"""
Singlestep solver DPM-Solver-3 from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (1,).
t: A pytorch tensor. The ending time, with the shape (1,).
r1: A `float`. The hyperparameter of the third-order solver.
r2: A `float`. The hyperparameter of the third-order solver.
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ["dpmsolver", "taylor"]:
raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
if r1 is None:
r1 = 1.0 / 3.0
if r2 is None:
r2 = 2.0 / 3.0
ns = self.noise_schedule
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
lambda_s1 = lambda_s + r1 * h
lambda_s2 = lambda_s + r2 * h
s1 = ns.inverse_lambda(lambda_s1)
s2 = ns.inverse_lambda(lambda_s2)
log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = (
ns.marginal_log_mean_coeff(s),
ns.marginal_log_mean_coeff(s1),
ns.marginal_log_mean_coeff(s2),
ns.marginal_log_mean_coeff(t),
)
sigma_s, sigma_s1, sigma_s2, sigma_t = (
ns.marginal_std(s),
ns.marginal_std(s1),
ns.marginal_std(s2),
ns.marginal_std(t),
)
alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
if self.algorithm_type == "dpmsolver++":
phi_11 = torch.expm1(-r1 * h)
phi_12 = torch.expm1(-r2 * h)
phi_1 = torch.expm1(-h)
phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0
phi_2 = phi_1 / h + 1.0
phi_3 = phi_2 / h - 0.5
if model_s is None:
model_s = self.model_fn(x, s)
if model_s1 is None:
x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
model_s1 = self.model_fn(x_s1, s1)
x_s2 = (
(sigma_s2 / sigma_s) * x
- (alpha_s2 * phi_12) * model_s
+ r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
)
model_s2 = self.model_fn(x_s2, s2)
if solver_type == "dpmsolver":
x_t = (
(sigma_t / sigma_s) * x
- (alpha_t * phi_1) * model_s
+ (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
)
elif solver_type == "taylor":
D1_0 = (1.0 / r1) * (model_s1 - model_s)
D1_1 = (1.0 / r2) * (model_s2 - model_s)
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
x_t = (
(sigma_t / sigma_s) * x
- (alpha_t * phi_1) * model_s
+ (alpha_t * phi_2) * D1
- (alpha_t * phi_3) * D2
)
else:
phi_11 = torch.expm1(r1 * h)
phi_12 = torch.expm1(r2 * h)
phi_1 = torch.expm1(h)
phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0
phi_2 = phi_1 / h - 1.0
phi_3 = phi_2 / h - 0.5
if model_s is None:
model_s = self.model_fn(x, s)
if model_s1 is None:
x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s
model_s1 = self.model_fn(x_s1, s1)
x_s2 = (
(torch.exp(log_alpha_s2 - log_alpha_s)) * x
- (sigma_s2 * phi_12) * model_s
- r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
)
model_s2 = self.model_fn(x_s2, s2)
if solver_type == "dpmsolver":
x_t = (
(torch.exp(log_alpha_t - log_alpha_s)) * x
- (sigma_t * phi_1) * model_s
- (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s)
)
elif solver_type == "taylor":
D1_0 = (1.0 / r1) * (model_s1 - model_s)
D1_1 = (1.0 / r2) * (model_s2 - model_s)
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
x_t = (
(torch.exp(log_alpha_t - log_alpha_s)) * x
- (sigma_t * phi_1) * model_s
- (sigma_t * phi_2) * D1
- (sigma_t * phi_3) * D2
)
if return_intermediate:
return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2}
else:
return x_t
def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
"""
Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
model_prev_list: A list of pytorch tensor. The previous computed model values.
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
t: A pytorch tensor. The ending time, with the shape (1,).
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ["dpmsolver", "taylor"]:
raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
ns = self.noise_schedule
model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
lambda_prev_1, lambda_prev_0, lambda_t = (
ns.marginal_lambda(t_prev_1),
ns.marginal_lambda(t_prev_0),
ns.marginal_lambda(t),
)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0 = h_0 / h
D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
if self.algorithm_type == "dpmsolver++":
phi_1 = torch.expm1(-h)
if solver_type == "dpmsolver":
x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0
elif solver_type == "taylor":
x_t = (
(sigma_t / sigma_prev_0) * x
- (alpha_t * phi_1) * model_prev_0
+ (alpha_t * (phi_1 / h + 1.0)) * D1_0
)
else:
phi_1 = torch.expm1(h)
if solver_type == "dpmsolver":
x_t = (
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
- (sigma_t * phi_1) * model_prev_0
- 0.5 * (sigma_t * phi_1) * D1_0
)
elif solver_type == "taylor":
x_t = (
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
- (sigma_t * phi_1) * model_prev_0
- (sigma_t * (phi_1 / h - 1.0)) * D1_0
)
return x_t
def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
"""
Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
model_prev_list: A list of pytorch tensor. The previous computed model values.
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
t: A pytorch tensor. The ending time, with the shape (1,).
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
ns = self.noise_schedule
model_prev_2, model_prev_1, model_prev_0 = model_prev_list
t_prev_2, t_prev_1, t_prev_0 = t_prev_list
lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = (
ns.marginal_lambda(t_prev_2),
ns.marginal_lambda(t_prev_1),
ns.marginal_lambda(t_prev_0),
ns.marginal_lambda(t),
)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
h_1 = lambda_prev_1 - lambda_prev_2
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0, r1 = h_0 / h, h_1 / h
D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2)
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
if self.algorithm_type == "dpmsolver++":
phi_1 = torch.expm1(-h)
phi_2 = phi_1 / h + 1.0
phi_3 = phi_2 / h - 0.5
return (
(sigma_t / sigma_prev_0) * x
- (alpha_t * phi_1) * model_prev_0
+ (alpha_t * phi_2) * D1
- (alpha_t * phi_3) * D2
)
else:
phi_1 = torch.expm1(h)
phi_2 = phi_1 / h - 1.0
phi_3 = phi_2 / h - 0.5
return (
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
- (sigma_t * phi_1) * model_prev_0
- (sigma_t * phi_2) * D1
- (sigma_t * phi_3) * D2
)
def singlestep_dpm_solver_update(
self, x, s, t, order, return_intermediate=False, solver_type="dpmsolver", r1=None, r2=None
):
"""
Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (1,).
t: A pytorch tensor. The ending time, with the shape (1,).
order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
r1: A `float`. The hyperparameter of the second-order or third-order solver.
r2: A `float`. The hyperparameter of the third-order solver.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if order == 1:
return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
elif order == 2:
return self.singlestep_dpm_solver_second_update(
x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1
)
elif order == 3:
return self.singlestep_dpm_solver_third_update(
x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2
)
else:
raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"):
"""
Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
model_prev_list: A list of pytorch tensor. The previous computed model values.
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
t: A pytorch tensor. The ending time, with the shape (1,).
order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if order == 1:
return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
elif order == 2:
return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
elif order == 3:
return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
else:
raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
def dpm_solver_adaptive(
self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type="dpmsolver"
):
"""
The adaptive step size solver based on singlestep DPM-Solver.
Args:
x: A pytorch tensor. The initial value at time `t_T`.
order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
h_init: A `float`. The initial step size (for logSNR).
atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
current time and `t_0` is less than `t_err`. The default setting is 1e-5.
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_0: A pytorch tensor. The approximated solution at time `t_0`.
[1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
"""
ns = self.noise_schedule
s = t_T * torch.ones((1,)).to(x)
lambda_s = ns.marginal_lambda(s)
lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
h = h_init * torch.ones_like(s).to(x)
x_prev = x
nfe = 0
if order == 2:
r1 = 0.5
lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(
x, s, t, r1=r1, solver_type=solver_type, **kwargs
)
elif order == 3:
r1, r2 = 1.0 / 3.0, 2.0 / 3.0
lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(
x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type
)
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(
x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
)
else:
raise ValueError(f"For adaptive step size solver, order must be 2 or 3, got {order}")
while torch.abs((s - t_0)).mean() > t_err:
t = ns.inverse_lambda(lambda_s + h)
x_lower, lower_noise_kwargs = lower_update(x, s, t)
x_higher = higher_update(x, s, t, **lower_noise_kwargs)
delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
E = norm_fn((x_higher - x_lower) / delta).max()
if torch.all(E <= 1.0):
x = x_higher
s = t
x_prev = x_lower
lambda_s = ns.marginal_lambda(s)
h = torch.min(theta * h * torch.float_power(E, -1.0 / order).float(), lambda_0 - lambda_s)
nfe += order
print("adaptive solver nfe", nfe)
return x
def add_noise(self, x, t, noise=None):
"""
Compute the noised input xt = alpha_t * x + sigma_t * noise.
Args:
x: A `torch.Tensor` with shape `(batch_size, *shape)`.
t: A `torch.Tensor` with shape `(t_size,)`.
Returns:
xt with shape `(t_size, batch_size, *shape)`.
"""
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
if noise is None:
noise = torch.randn((t.shape[0], *x.shape), device=x.device)
x = x.reshape((-1, *x.shape))
xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise
return xt.squeeze(0) if t.shape[0] == 1 else xt
def inverse(
self,
x,
steps=20,
t_start=None,
t_end=None,
order=2,
skip_type="time_uniform",
method="multistep",
lower_order_final=True,
denoise_to_zero=False,
solver_type="dpmsolver",
atol=0.0078,
rtol=0.05,
return_intermediate=False,
):
"""
Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver.
For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.
"""
t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start
t_T = self.noise_schedule.T if t_end is None else t_end
assert (
t_0 > 0 and t_T > 0
), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
return self.sample(
x,
steps=steps,
t_start=t_0,
t_end=t_T,
order=order,
skip_type=skip_type,
method=method,
lower_order_final=lower_order_final,
denoise_to_zero=denoise_to_zero,
solver_type=solver_type,
atol=atol,
rtol=rtol,
return_intermediate=return_intermediate,
)
def sample(
self,
x,
steps=20,
t_start=None,
t_end=None,
order=2,
skip_type="time_uniform",
method="multistep",
lower_order_final=True,
denoise_to_zero=False,
solver_type="dpmsolver",
atol=0.0078,
rtol=0.05,
return_intermediate=False,
):
"""
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
=====================================================
We support the following algorithms for both noise prediction model and data prediction model:
- 'singlestep':
Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
The total number of function evaluations (NFE) == `steps`.
Given a fixed NFE == `steps`, the sampling procedure is:
- If `order` == 1:
- Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
- If `order` == 2:
- Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
- If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
- If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
- If `order` == 3:
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
- If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
- If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
- If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
- 'multistep':
Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
We initialize the first `order` values by lower order multistep solvers.
Given a fixed NFE == `steps`, the sampling procedure is:
Denote K = steps.
- If `order` == 1:
- We use K steps of DPM-Solver-1 (i.e. DDIM).
- If `order` == 2:
- We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
- If `order` == 3:
- We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
- 'singlestep_fixed':
Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
- 'adaptive':
Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
(NFE) and the sample quality.
- If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
- If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
=====================================================
Some advices for choosing the algorithm:
- For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
e.g., DPM-Solver:
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
skip_type='time_uniform', method='singlestep')
e.g., DPM-Solver++:
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
skip_type='time_uniform', method='singlestep')
- For **guided sampling with large guidance scale** by DPMs:
Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
e.g.
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
skip_type='time_uniform', method='multistep')
We support three types of `skip_type`:
- 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
- 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
- 'time_quadratic': quadratic time for the time steps.
=====================================================
Args:
x: A pytorch tensor. The initial value at time `t_start`
e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
steps: A `int`. The total number of function evaluations (NFE).
t_start: A `float`. The starting time of the sampling.
If `T` is None, we use self.noise_schedule.T (default is 1.0).
t_end: A `float`. The ending time of the sampling.
If `t_end` is None, we use 1. / self.noise_schedule.total_N.
e.g. if total_N == 1000, we have `t_end` == 1e-3.
For discrete-time DPMs:
- We recommend `t_end` == 1. / self.noise_schedule.total_N.
For continuous-time DPMs:
- We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
order: A `int`. The order of DPM-Solver.
skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
for diffusion models sampling by diffusion SDEs for low-resolutional images
(such as CIFAR-10). However, we observed that such trick does not matter for
high-resolutional images. As it needs an additional NFE, we do not recommend
it for high-resolutional images.
lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
Only valid for `method=multistep` and `steps < 15`. We empirically find that
this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
(especially for steps <= 10). So we recommend to set it to be `True`.
solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`.
atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
return_intermediate: A `bool`. Whether to save the xt at each step.
When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0.
Returns:
x_end: A pytorch tensor. The approximated solution at time `t_end`.
"""
t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
t_T = self.noise_schedule.T if t_start is None else t_start
assert (
t_0 > 0 and t_T > 0
), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
if return_intermediate:
assert method in [
"multistep",
"singlestep",
"singlestep_fixed",
], "Cannot use adaptive solver when saving intermediate values"
if self.correcting_xt_fn is not None:
assert method in [
"multistep",
"singlestep",
"singlestep_fixed",
], "Cannot use adaptive solver when correcting_xt_fn is not None"
device = x.device
intermediates = []
with torch.no_grad():
if method == "adaptive":
x = self.dpm_solver_adaptive(
x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type
)
elif method == "multistep":
assert steps >= order
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
assert timesteps.shape[0] - 1 == steps
# Init the initial values.
step = 0
t = timesteps[step]
t_prev_list = [t]
model_prev_list = [self.model_fn(x, t)]
if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step)
if return_intermediate:
intermediates.append(x)
# Init the first `order` values by lower order multistep DPM-Solver.
for step in range(1, order):
t = timesteps[step]
x = self.multistep_dpm_solver_update(
x, model_prev_list, t_prev_list, t, step, solver_type=solver_type
)
if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step)
if return_intermediate:
intermediates.append(x)
t_prev_list.append(t)
model_prev_list.append(self.model_fn(x, t))
# Compute the remaining values by `order`-th order multistep DPM-Solver.
for step in tqdm(range(order, steps + 1)):
t = timesteps[step]
# We only use lower order for steps < 10
if lower_order_final and steps < 10:
step_order = min(order, steps + 1 - step)
else:
step_order = order
x = self.multistep_dpm_solver_update(
x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type
)
if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step)
if return_intermediate:
intermediates.append(x)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = t
# We do not need to evaluate the final model value.
if step < steps:
model_prev_list[-1] = self.model_fn(x, t)
elif method in ["singlestep", "singlestep_fixed"]:
if method == "singlestep":
timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(
steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device
)
elif method == "singlestep_fixed":
K = steps // order
orders = [
order,
] * K
timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
for step, order in enumerate(orders):
s, t = timesteps_outer[step], timesteps_outer[step + 1]
timesteps_inner = self.get_time_steps(
skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device
)
lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
h = lambda_inner[-1] - lambda_inner[0]
r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step)
if return_intermediate:
intermediates.append(x)
else:
raise ValueError(f"Got wrong method {method}")
if denoise_to_zero:
t = torch.ones((1,)).to(device) * t_0
x = self.denoise_to_zero_fn(x, t)
if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step + 1)
if return_intermediate:
intermediates.append(x)
return (x, intermediates) if return_intermediate else x
#############################################################
# other utility functions
#############################################################
def interpolate_fn(x, xp, yp):
"""
A piecewise linear function y = f(x), using xp and yp as keypoints.
We implement f(x) in a differentiable way (i.e. applicable for autograd).
The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
Args:
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
yp: PyTorch tensor with shape [C, K].
Returns:
The function values f(x), with shape [N, C].
"""
N, K = x.shape[0], xp.shape[1]
all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
x_idx = torch.argmin(x_indices, dim=2)
cand_start_idx = x_idx - 1
start_idx = torch.where(
torch.eq(x_idx, 0),
torch.tensor(1, device=x.device),
torch.where(
torch.eq(x_idx, K),
torch.tensor(K - 2, device=x.device),
cand_start_idx,
),
)
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
start_idx2 = torch.where(
torch.eq(x_idx, 0),
torch.tensor(0, device=x.device),
torch.where(
torch.eq(x_idx, K),
torch.tensor(K - 2, device=x.device),
cand_start_idx,
),
)
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
return start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
def expand_dims(v, dims):
"""
Expand the tensor `v` to the dim `dims`.
Args:
`v`: a PyTorch tensor with shape [N].
`dim`: a `int`.
Returns:
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
"""
return v[(...,) + (None,) * (dims - 1)]
def DPMS(
model,
condition,
uncondition,
cfg_scale,
model_type="noise",
noise_schedule="linear",
guidance_type="classifier-free",
model_kwargs=None,
diffusion_steps=1000,
):
if model_kwargs is None:
model_kwargs = {}
betas = torch.tensor(get_named_beta_schedule(noise_schedule, diffusion_steps))
## 1. Define the noise schedule.
noise_schedule = NoiseScheduleVP(schedule="discrete", betas=betas)
## 2. Convert your discrete-time `model` to the continuous-time
## noise prediction model. Here is an example for a diffusion model
## `model` with the noise prediction type ("noise") .
model_fn = model_wrapper(
model,
noise_schedule,
model_type=model_type,
model_kwargs=model_kwargs,
guidance_type=guidance_type,
condition=condition,
unconditional_condition=uncondition,
guidance_scale=cfg_scale,
)
## 3. Define dpm-solver and sample by multistep DPM-Solver.
return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
from functools import partial
import torch
from opensora.registry import SCHEDULERS
from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps
@SCHEDULERS.register_module("iddpm")
class IDDPM(SpacedDiffusion):
def __init__(
self,
num_sampling_steps=None,
timestep_respacing=None,
noise_schedule="linear",
use_kl=False,
sigma_small=False,
predict_xstart=False,
learn_sigma=True,
rescale_learned_sigmas=False,
diffusion_steps=1000,
cfg_scale=4.0,
):
betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
if use_kl:
loss_type = gd.LossType.RESCALED_KL
elif rescale_learned_sigmas:
loss_type = gd.LossType.RESCALED_MSE
else:
loss_type = gd.LossType.MSE
if num_sampling_steps is not None:
assert timestep_respacing is None
timestep_respacing = str(num_sampling_steps)
if timestep_respacing is None or timestep_respacing == "":
timestep_respacing = [diffusion_steps]
super().__init__(
use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
betas=betas,
model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X),
model_var_type=(
(gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
if not learn_sigma
else gd.ModelVarType.LEARNED_RANGE
),
loss_type=loss_type,
# rescale_timesteps=rescale_timesteps,
)
self.cfg_scale = cfg_scale
def sample(
self,
model,
text_encoder,
z_size,
prompts,
device,
additional_args=None,
):
n = len(prompts)
z = torch.randn(n, *z_size, device=device)
z = torch.cat([z, z], 0)
model_args = text_encoder.encode(prompts)
y_null = text_encoder.null(n)
model_args["y"] = torch.cat([model_args["y"], y_null], 0)
if additional_args is not None:
model_args.update(additional_args)
forward = partial(forward_with_cfg, model, cfg_scale=self.cfg_scale)
samples = self.p_sample_loop(
forward,
z.shape,
z,
clip_denoised=False,
model_kwargs=model_args,
progress=True,
device=device,
)
samples, _ = samples.chunk(2, dim=0)
return samples
def forward_with_cfg(model, x, timestep, y, cfg_scale, **kwargs):
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = torch.cat([half, half], dim=0)
model_out = model.forward(combined, timestep, y, **kwargs)
model_out = model_out["x"] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
return torch.cat([eps, rest], dim=1)
# Adapted from DiT
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT/tree/main
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# --------------------------------------------------------
import numpy as np
import torch as th
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, th.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for th.exp().
logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)]
return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2))
def approx_standard_normal_cdf(x):
"""
A fast approximation of the cumulative distribution function of the
standard normal.
"""
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
"""
Compute the log-likelihood of a continuous Gaussian distribution.
:param x: the targets
:param means: the Gaussian mean Tensor.
:param log_scales: the Gaussian log stddev Tensor.
:return: a tensor like x of log probabilities (in nats).
"""
centered_x = x - means
inv_stdv = th.exp(-log_scales)
normalized_x = centered_x * inv_stdv
log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
return log_probs
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
"""
Compute the log-likelihood of a Gaussian distribution discretizing to a
given image.
:param x: the target images. It is assumed that this was uint8 values,
rescaled to the range [-1, 1].
:param means: the Gaussian mean Tensor.
:param log_scales: the Gaussian log stddev Tensor.
:return: a tensor like x of log probabilities (in nats).
"""
assert x.shape == means.shape == log_scales.shape
centered_x = x - means
inv_stdv = th.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = th.where(
x < -0.999,
log_cdf_plus,
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
)
assert log_probs.shape == x.shape
return log_probs
# Adapted from DiT
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT/tree/main
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# --------------------------------------------------------
import enum
import math
import numpy as np
import torch as th
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
class ModelMeanType(enum.Enum):
"""
Which type of output the model predicts.
"""
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
START_X = enum.auto() # the model predicts x_0
EPSILON = enum.auto() # the model predicts epsilon
class ModelVarType(enum.Enum):
"""
What is used as the model's output variance.
The LEARNED_RANGE option has been added to allow the model to predict
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
"""
LEARNED = enum.auto()
FIXED_SMALL = enum.auto()
FIXED_LARGE = enum.auto()
LEARNED_RANGE = enum.auto()
class LossType(enum.Enum):
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
RESCALED_MSE = enum.auto() # use raw MSE loss (with RESCALED_KL when learning variances)
KL = enum.auto() # use the variational lower-bound
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
def is_vb(self):
return self == LossType.KL or self == LossType.RESCALED_KL
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
warmup_time = int(num_diffusion_timesteps * warmup_frac)
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
return betas
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
"""
This is the deprecated API for creating beta schedules.
See get_named_beta_schedule() for the new library of schedules.
"""
if beta_schedule == "quad":
betas = (
np.linspace(
beta_start**0.5,
beta_end**0.5,
num_diffusion_timesteps,
dtype=np.float64,
)
** 2
)
elif beta_schedule == "linear":
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
elif beta_schedule == "warmup10":
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
elif beta_schedule == "warmup50":
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
elif beta_schedule == "const":
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
else:
raise NotImplementedError(beta_schedule)
assert betas.shape == (num_diffusion_timesteps,)
return betas
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
"""
Get a pre-defined beta schedule for the given name.
The beta schedule library consists of beta schedules which remain similar
in the limit of num_diffusion_timesteps.
Beta schedules may be added, but should not be removed or changed once
they are committed to maintain backwards compatibility.
"""
if schedule_name == "linear":
# Linear schedule from Ho et al, extended to work for any number of
# diffusion steps.
scale = 1000 / num_diffusion_timesteps
return get_beta_schedule(
"linear",
beta_start=scale * 0.0001,
beta_end=scale * 0.02,
num_diffusion_timesteps=num_diffusion_timesteps,
)
elif schedule_name == "squaredcos_cap_v2":
return betas_for_alpha_bar(
num_diffusion_timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
class GaussianDiffusion:
"""
Utilities for training and sampling diffusion models.
Original ported from this codebase:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
:param betas: a 1-D numpy array of betas for each diffusion timestep,
starting at T and going to 1.
"""
def __init__(self, *, betas, model_mean_type, model_var_type, loss_type):
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
self.loss_type = loss_type
# Use float64 for accuracy.
betas = np.array(betas, dtype=np.float64)
self.betas = betas
assert len(betas.shape) == 1, "betas must be 1-D"
assert (betas > 0).all() and (betas <= 1).all()
self.num_timesteps = int(betas.shape[0])
alphas = 1.0 - betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = (
np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
if len(self.posterior_variance) > 1
else np.array([])
)
self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data for a given number of diffusion steps.
In other words, sample from q(x_t | x_0).
:param x_start: the initial data batch.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:param noise: if specified, the split-out normal noise.
:return: A noisy version of x_start.
"""
if noise is None:
noise = th.randn_like(x_start)
assert noise.shape == x_start.shape
return (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior:
q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
== posterior_log_variance_clipped.shape[0]
== x_start.shape[0]
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
"""
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0.
:param model: the model, which takes a signal and a batch of timesteps
as input.
:param x: the [N x C x ...] tensor at time t.
:param t: a 1-D Tensor of timesteps.
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample. Applies before
clip_denoised.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict with the following keys:
- 'mean': the model mean output.
- 'variance': the model variance output.
- 'log_variance': the log of 'variance'.
- 'pred_xstart': the prediction for x_0.
"""
if model_kwargs is None:
model_kwargs = {}
B, C = x.shape[:2]
assert t.shape == (B,)
model_output = model(x, t, **model_kwargs)
if isinstance(model_output, tuple):
model_output, extra = model_output
else:
extra = None
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
assert model_output.shape == (B, C * 2, *x.shape[2:])
model_output, model_var_values = th.split(model_output, C, dim=1)
min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
# The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2
model_log_variance = frac * max_log + (1 - frac) * min_log
model_variance = th.exp(model_log_variance)
else:
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so
# to get a better decoder log likelihood.
ModelVarType.FIXED_LARGE: (
np.append(self.posterior_variance[1], self.betas[1:]),
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
),
ModelVarType.FIXED_SMALL: (
self.posterior_variance,
self.posterior_log_variance_clipped,
),
}[self.model_var_type]
model_variance = _extract_into_tensor(model_variance, t, x.shape)
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
def process_xstart(x):
if denoised_fn is not None:
x = denoised_fn(x)
if clip_denoised:
return x.clamp(-1, 1)
return x
if self.model_mean_type == ModelMeanType.START_X:
pred_xstart = process_xstart(model_output)
else:
pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
return {
"mean": model_mean,
"variance": model_variance,
"log_variance": model_log_variance,
"pred_xstart": pred_xstart,
"extra": extra,
}
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
)
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
"""
Compute the mean for the previous step, given a function cond_fn that
computes the gradient of a conditional log probability with respect to
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
condition on y.
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
"""
gradient = cond_fn(x, t, **model_kwargs)
new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
return new_mean
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
"""
Compute what the p_mean_variance output would have been, should the
model's score function be conditioned by cond_fn.
See condition_mean() for details on cond_fn.
Unlike condition_mean(), this instead uses the conditioning strategy
from Song et al (2020).
"""
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
out = p_mean_var.copy()
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
return out
def p_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
):
"""
Sample x_{t-1} from the model at the given timestep.
:param model: the model to sample from.
:param x: the current tensor at x_{t-1}.
:param t: the value of t, starting at 0 for the first diffusion step.
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param cond_fn: if not None, this is a gradient function that acts
similarly to the model.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
- 'sample': a random sample from the model.
- 'pred_xstart': a prediction of x_0.
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
noise = th.randn_like(x)
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
if cond_fn is not None:
out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def p_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
):
"""
Generate samples from the model.
:param model: the model module.
:param shape: the shape of the samples, (N, C, H, W).
:param noise: if specified, the noise from the encoder to sample.
Should be of the same shape as `shape`.
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param cond_fn: if not None, this is a gradient function that acts
similarly to the model.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param device: if specified, the device to create the samples on.
If not specified, use a model parameter's device.
:param progress: if True, show a tqdm progress bar.
:return: a non-differentiable batch of samples.
"""
final = None
for sample in self.p_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
):
final = sample
return final["sample"]
def p_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
):
"""
Generate samples from the model and yield intermediate samples from
each timestep of diffusion.
Arguments are the same as p_sample_loop().
Returns a generator over dicts, where each dict is the return value of
p_sample().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.p_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
)
yield out
img = out["sample"]
def ddim_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t-1} from the model using DDIM.
Same usage as p_sample().
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
if cond_fn is not None:
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
# Equation 12.
noise = th.randn_like(x)
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
sample = mean_pred + nonzero_mask * sigma * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def ddim_reverse_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t+1} from the model using DDIM reverse ODE.
"""
assert eta == 0.0, "Reverse ODE only for deterministic path"
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
if cond_fn is not None:
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"]
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
# Equation 12. reversed
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
def ddim_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
):
"""
Generate samples from the model using DDIM.
Same usage as p_sample_loop().
"""
final = None
for sample in self.ddim_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
eta=eta,
):
final = sample
return final["sample"]
def ddim_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
):
"""
Use DDIM to sample from the model and yield intermediate samples from
each timestep of DDIM.
Same usage as p_sample_loop_progressive().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.ddim_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
eta=eta,
)
yield out
img = out["sample"]
def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None):
"""
Get a term for the variational lower-bound.
The resulting units are bits (rather than nats, as one might expect).
This allows for comparison to other papers.
:return: a dict with the following keys:
- 'output': a shape [N] tensor of NLLs or KLs.
- 'pred_xstart': the x_0 predictions.
"""
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
kl = mean_flat(kl) / np.log(2.0)
decoder_nll = -discretized_gaussian_log_likelihood(
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
)
assert decoder_nll.shape == x_start.shape
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
# At the first timestep return the decoder NLL,
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
output = th.where((t == 0), decoder_nll, kl)
return {"output": output, "pred_xstart": out["pred_xstart"]}
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
"""
Compute training losses for a single timestep.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param t: a batch of timestep indices.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param noise: if specified, the specific Gaussian noise to try to remove.
:return: a dict with the key "loss" containing a tensor of shape [N].
Some mean or variance settings may also have other keys.
"""
if model_kwargs is None:
model_kwargs = {}
if noise is None:
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start, t, noise=noise)
terms = {}
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
terms["loss"] = self._vb_terms_bpd(
model=model,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
model_kwargs=model_kwargs,
)["output"]
if self.loss_type == LossType.RESCALED_KL:
terms["loss"] *= self.num_timesteps
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
model_output = model(x_t, t, **model_kwargs)
if self.model_var_type in [
ModelVarType.LEARNED,
ModelVarType.LEARNED_RANGE,
]:
B, C = x_t.shape[:2]
assert model_output.shape == (B, C * 2, *x_t.shape[2:])
model_output, model_var_values = th.split(model_output, C, dim=1)
# Learn the variance using the variational bound, but don't let
# it affect our mean prediction.
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
terms["vb"] = self._vb_terms_bpd(
model=lambda *args, r=frozen_out: r,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
)["output"]
if self.loss_type == LossType.RESCALED_MSE:
# Divide by 1000 for equivalence with initial implementation.
# Without a factor of 1/1000, the VB term hurts the MSE term.
terms["vb"] *= self.num_timesteps / 1000.0
target = {
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
ModelMeanType.START_X: x_start,
ModelMeanType.EPSILON: noise,
}[self.model_mean_type]
assert model_output.shape == target.shape == x_start.shape
terms["mse"] = mean_flat((target - model_output) ** 2)
if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"]
else:
terms["loss"] = terms["mse"]
else:
raise NotImplementedError(self.loss_type)
return terms
def _prior_bpd(self, x_start):
"""
Get the prior KL term for the variational lower-bound, measured in
bits-per-dim.
This term can't be optimized, as it only depends on the encoder.
:param x_start: the [N x C x ...] tensor of inputs.
:return: a batch of [N] KL values (in bits), one per batch element.
"""
batch_size = x_start.shape[0]
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
return mean_flat(kl_prior) / np.log(2.0)
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
"""
Compute the entire variational lower-bound, measured in bits-per-dim,
as well as other related quantities.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param clip_denoised: if True, clip denoised samples.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
- total_bpd: the total variational lower-bound, per batch element.
- prior_bpd: the prior term in the lower-bound.
- vb: an [N x T] tensor of terms in the lower-bound.
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
"""
device = x_start.device
batch_size = x_start.shape[0]
vb = []
xstart_mse = []
mse = []
for t in list(range(self.num_timesteps))[::-1]:
t_batch = th.tensor([t] * batch_size, device=device)
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
# Calculate VLB term at the current timestep
with th.no_grad():
out = self._vb_terms_bpd(
model,
x_start=x_start,
x_t=x_t,
t=t_batch,
clip_denoised=clip_denoised,
model_kwargs=model_kwargs,
)
vb.append(out["output"])
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
mse.append(mean_flat((eps - noise) ** 2))
vb = th.stack(vb, dim=1)
xstart_mse = th.stack(xstart_mse, dim=1)
mse = th.stack(mse, dim=1)
prior_bpd = self._prior_bpd(x_start)
total_bpd = vb.sum(dim=1) + prior_bpd
return {
"total_bpd": total_bpd,
"prior_bpd": prior_bpd,
"vb": vb,
"xstart_mse": xstart_mse,
"mse": mse,
}
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + th.zeros(broadcast_shape, device=timesteps.device)
# Adapted from DiT
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT/tree/main
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# --------------------------------------------------------
import numpy as np
import torch as th
from .gaussian_diffusion import GaussianDiffusion
def space_timesteps(num_timesteps, section_counts):
"""
Create a list of timesteps to use from an original diffusion process,
given the number of timesteps we want to take from equally-sized portions
of the original process.
For example, if there's 300 timesteps and the section counts are [10,15,20]
then the first 100 timesteps are strided to be 10 timesteps, the second 100
are strided to be 15 timesteps, and the final 100 are strided to be 20.
If the stride is a string starting with "ddim", then the fixed striding
from the DDIM paper is used, and only one section is allowed.
:param num_timesteps: the number of diffusion steps in the original
process to divide up.
:param section_counts: either a list of numbers, or a string containing
comma-separated numbers, indicating the step count
per section. As a special case, use "ddimN" where N
is a number of steps to use the striding from the
DDIM paper.
:return: a set of diffusion steps from the original process to use.
"""
if isinstance(section_counts, str):
if section_counts.startswith("ddim"):
desired_count = int(section_counts[len("ddim") :])
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return set(range(0, num_timesteps, i))
raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
section_counts = [int(x) for x in section_counts.split(",")]
size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts)
start_idx = 0
all_steps = []
for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0)
if size < section_count:
raise ValueError(f"cannot divide section of {size} steps into {section_count}")
if section_count <= 1:
frac_stride = 1
else:
frac_stride = (size - 1) / (section_count - 1)
cur_idx = 0.0
taken_steps = []
for _ in range(section_count):
taken_steps.append(start_idx + round(cur_idx))
cur_idx += frac_stride
all_steps += taken_steps
start_idx += size
return set(all_steps)
class SpacedDiffusion(GaussianDiffusion):
"""
A diffusion process which can skip steps in a base diffusion process.
:param use_timesteps: a collection (sequence or set) of timesteps from the
original diffusion process to retain.
:param kwargs: the kwargs to create the base diffusion process.
"""
def __init__(self, use_timesteps, **kwargs):
self.use_timesteps = set(use_timesteps)
self.timestep_map = []
self.original_num_steps = len(kwargs["betas"])
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
last_alpha_cumprod = 1.0
new_betas = []
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
if i in self.use_timesteps:
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
last_alpha_cumprod = alpha_cumprod
self.timestep_map.append(i)
kwargs["betas"] = np.array(new_betas)
super().__init__(**kwargs)
def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
return super().training_losses(self._wrap_model(model), *args, **kwargs)
def condition_mean(self, cond_fn, *args, **kwargs):
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
def condition_score(self, cond_fn, *args, **kwargs):
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
def _wrap_model(self, model):
if isinstance(model, _WrappedModel):
return model
return _WrappedModel(model, self.timestep_map, self.original_num_steps)
def _scale_timesteps(self, t):
# Scaling is done by the wrapped model.
return t
class _WrappedModel:
def __init__(self, model, timestep_map, original_num_steps):
self.model = model
self.timestep_map = timestep_map
# self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps
def __call__(self, x, ts, **kwargs):
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
new_ts = map_tensor[ts]
# if self.rescale_timesteps:
# new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
return self.model(x, new_ts, **kwargs)
# Adapted from DiT
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT/tree/main
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# --------------------------------------------------------
from abc import ABC, abstractmethod
import numpy as np
import torch as th
import torch.distributed as dist
def create_named_schedule_sampler(name, diffusion):
"""
Create a ScheduleSampler from a library of pre-defined samplers.
:param name: the name of the sampler.
:param diffusion: the diffusion object to sample for.
"""
if name == "uniform":
return UniformSampler(diffusion)
elif name == "loss-second-moment":
return LossSecondMomentResampler(diffusion)
else:
raise NotImplementedError(f"unknown schedule sampler: {name}")
class ScheduleSampler(ABC):
"""
A distribution over timesteps in the diffusion process, intended to reduce
variance of the objective.
By default, samplers perform unbiased importance sampling, in which the
objective's mean is unchanged.
However, subclasses may override sample() to change how the resampled
terms are reweighted, allowing for actual changes in the objective.
"""
@abstractmethod
def weights(self):
"""
Get a numpy array of weights, one per diffusion step.
The weights needn't be normalized, but must be positive.
"""
def sample(self, batch_size, device):
"""
Importance-sample timesteps for a batch.
:param batch_size: the number of timesteps.
:param device: the torch device to save to.
:return: a tuple (timesteps, weights):
- timesteps: a tensor of timestep indices.
- weights: a tensor of weights to scale the resulting losses.
"""
w = self.weights()
p = w / np.sum(w)
indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
indices = th.from_numpy(indices_np).long().to(device)
weights_np = 1 / (len(p) * p[indices_np])
weights = th.from_numpy(weights_np).float().to(device)
return indices, weights
class UniformSampler(ScheduleSampler):
def __init__(self, diffusion):
self.diffusion = diffusion
self._weights = np.ones([diffusion.num_timesteps])
def weights(self):
return self._weights
class LossAwareSampler(ScheduleSampler):
def update_with_local_losses(self, local_ts, local_losses):
"""
Update the reweighting using losses from a model.
Call this method from each rank with a batch of timesteps and the
corresponding losses for each of those timesteps.
This method will perform synchronization to make sure all of the ranks
maintain the exact same reweighting.
:param local_ts: an integer Tensor of timesteps.
:param local_losses: a 1D Tensor of losses.
"""
batch_sizes = [th.tensor([0], dtype=th.int32, device=local_ts.device) for _ in range(dist.get_world_size())]
dist.all_gather(
batch_sizes,
th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
)
# Pad all_gather batches to be the maximum batch size.
batch_sizes = [x.item() for x in batch_sizes]
max_bs = max(batch_sizes)
timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
dist.all_gather(timestep_batches, local_ts)
dist.all_gather(loss_batches, local_losses)
timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]]
losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
self.update_with_all_losses(timesteps, losses)
@abstractmethod
def update_with_all_losses(self, ts, losses):
"""
Update the reweighting using losses from a model.
Sub-classes should override this method to update the reweighting
using losses from the model.
This method directly updates the reweighting without synchronizing
between workers. It is called by update_with_local_losses from all
ranks with identical arguments. Thus, it should have deterministic
behavior to maintain state across workers.
:param ts: a list of int timesteps.
:param losses: a list of float losses, one per timestep.
"""
class LossSecondMomentResampler(LossAwareSampler):
def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
self.diffusion = diffusion
self.history_per_term = history_per_term
self.uniform_prob = uniform_prob
self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64)
self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
def weights(self):
if not self._warmed_up():
return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
weights /= np.sum(weights)
weights *= 1 - self.uniform_prob
weights += self.uniform_prob / len(weights)
return weights
def update_with_all_losses(self, ts, losses):
for t, loss in zip(ts, losses):
if self._loss_counts[t] == self.history_per_term:
# Shift out the oldest loss term.
self._loss_history[t, :-1] = self._loss_history[t, 1:]
self._loss_history[t, -1] = loss
else:
self._loss_history[t, self._loss_counts[t]] = loss
self._loss_counts[t] += 1
def _warmed_up(self):
return (self._loss_counts == self.history_per_term).all()
import functools
import json
import logging
import operator
import os
from typing import Tuple
import colossalai
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torchvision.datasets.utils import download_url
pretrained_models = {
"DiT-XL-2-512x512.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt",
"DiT-XL-2-256x256.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt",
"Latte-XL-2-256x256-ucf101.pt": "https://huggingface.co/maxin-cn/Latte/resolve/main/ucf101.pt",
"PixArt-XL-2-256x256.pth": "https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-256x256.pth",
"PixArt-XL-2-SAM-256x256.pth": "https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-SAM-256x256.pth",
"PixArt-XL-2-512x512.pth": "https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-512x512.pth",
"PixArt-XL-2-1024-MS.pth": "https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-1024-MS.pth",
}
def reparameter(ckpt, name=None):
if "DiT" in name:
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
del ckpt["pos_embed"]
elif "Latte" in name:
ckpt = ckpt["ema"]
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
del ckpt["pos_embed"]
del ckpt["temp_embed"]
elif "PixArt" in name:
ckpt = ckpt["state_dict"]
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
del ckpt["pos_embed"]
return ckpt
def find_model(model_name):
"""
Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path.
"""
if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints
model = download_model(model_name)
model = reparameter(model, model_name)
return model
else: # Load a custom DiT checkpoint:
assert os.path.isfile(model_name), f"Could not find DiT checkpoint at {model_name}"
checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
if "pos_embed_temporal" in checkpoint:
del checkpoint["pos_embed_temporal"]
if "pos_embed" in checkpoint:
del checkpoint["pos_embed"]
if "ema" in checkpoint: # supports checkpoints from train.py
checkpoint = checkpoint["ema"]
return checkpoint
def download_model(model_name):
"""
Downloads a pre-trained DiT model from the web.
"""
assert model_name in pretrained_models
local_path = f"pretrained_models/{model_name}"
if not os.path.isfile(local_path):
os.makedirs("pretrained_models", exist_ok=True)
web_path = pretrained_models[model_name]
download_url(web_path, "pretrained_models", model_name)
model = torch.load(local_path, map_location=lambda storage, loc: storage)
return model
def load_from_sharded_state_dict(model, ckpt_path):
ckpt_io = GeneralCheckpointIO()
ckpt_io.load_model(model, os.path.join(ckpt_path, "model"))
def model_sharding(model: torch.nn.Module):
global_rank = dist.get_rank()
world_size = dist.get_world_size()
for _, param in model.named_parameters():
padding_size = (world_size - param.numel() % world_size) % world_size
if padding_size > 0:
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
else:
padding_param = param.data.view(-1)
splited_params = padding_param.split(padding_param.numel() // world_size)
splited_params = splited_params[global_rank]
param.data = splited_params
def load_json(file_path: str):
with open(file_path, "r") as f:
return json.load(f)
def save_json(data, file_path: str):
with open(file_path, "w") as f:
json.dump(data, f, indent=4)
def remove_padding(tensor: torch.Tensor, original_shape: Tuple) -> torch.Tensor:
return tensor[: functools.reduce(operator.mul, original_shape)]
def model_gathering(model: torch.nn.Module, model_shape_dict: dict):
global_rank = dist.get_rank()
global_size = dist.get_world_size()
for name, param in model.named_parameters():
all_params = [torch.empty_like(param.data) for _ in range(global_size)]
dist.all_gather(all_params, param.data, group=dist.group.WORLD)
if int(global_rank) == 0:
all_params = torch.cat(all_params)
param.data = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name])
dist.barrier()
def record_model_param_shape(model: torch.nn.Module) -> dict:
param_shape = {}
for name, param in model.named_parameters():
param_shape[name] = param.shape
return param_shape
def save(
booster: Booster,
model: nn.Module,
ema: nn.Module,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
epoch: int,
step: int,
global_step: int,
batch_size: int,
coordinator: DistCoordinator,
save_dir: str,
shape_dict: dict,
):
save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{global_step}")
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
# ema is not boosted, so we don't need to use booster.save_model
model_gathering(ema, shape_dict)
global_rank = dist.get_rank()
if int(global_rank) == 0:
torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt"))
model_sharding(ema)
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
if lr_scheduler is not None:
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
running_states = {
"epoch": epoch,
"step": step,
"global_step": global_step,
"sample_start_index": step * batch_size,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
dist.barrier()
def load(
booster: Booster, model: nn.Module, ema: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str
) -> Tuple[int, int, int]:
booster.load_model(model, os.path.join(load_dir, "model"))
# ema is not boosted, so we don't use booster.load_model
# ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt")))
ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")))
booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
if lr_scheduler is not None:
booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
running_states = load_json(os.path.join(load_dir, "running_states.json"))
dist.barrier()
return running_states["epoch"], running_states["step"], running_states["sample_start_index"]
def create_logger(logging_dir):
"""
Create a logger that writes to a log file and stdout.
"""
if dist.get_rank() == 0: # real logger
logging.basicConfig(
level=logging.INFO,
format="[\033[34m%(asctime)s\033[0m] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")],
)
logger = logging.getLogger(__name__)
else: # dummy logger (does nothing)
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
return logger
def load_checkpoint(model, ckpt_path, save_as_pt=True):
if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
state_dict = find_model(ckpt_path)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
print(f"Missing keys: {missing_keys}")
print(f"Unexpected keys: {unexpected_keys}")
elif os.path.isdir(ckpt_path):
load_from_sharded_state_dict(model, ckpt_path)
if save_as_pt:
save_path = os.path.join(ckpt_path, "model_ckpt.pt")
torch.save(model.state_dict(), save_path)
print(f"Model checkpoint saved to {save_path}")
else:
raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
import argparse
import json
import os
from glob import glob
from mmengine.config import Config
from torch.utils.tensorboard import SummaryWriter
def parse_args(training=False):
parser = argparse.ArgumentParser()
# model config
parser.add_argument("config", help="model config file path")
parser.add_argument("--seed", default=42, type=int, help="generation seed")
parser.add_argument("--ckpt-path", type=str, help="path to model ckpt; will overwrite cfg.ckpt_path if specified")
parser.add_argument("--batch-size", default=None, type=int, help="batch size")
# ======================================================
# Inference
# ======================================================
if not training:
# prompt
parser.add_argument("--prompt-path", default=None, type=str, help="path to prompt txt file")
parser.add_argument("--save-dir", default=None, type=str, help="path to save generated samples")
# hyperparameters
parser.add_argument("--num-sampling-steps", default=None, type=int, help="sampling steps")
parser.add_argument("--cfg-scale", default=None, type=float, help="balance between cond & uncond")
else:
parser.add_argument("--wandb", default=None, type=bool, help="enable wandb")
parser.add_argument("--load", default=None, type=str, help="path to continue training")
parser.add_argument("--data-path", default=None, type=str, help="path to data csv")
return parser.parse_args()
def merge_args(cfg, args, training=False):
if args.ckpt_path is not None:
cfg.model["from_pretrained"] = args.ckpt_path
args.ckpt_path = None
if not training:
if args.cfg_scale is not None:
cfg.scheduler["cfg_scale"] = args.cfg_scale
args.cfg_scale = None
if "multi_resolution" not in cfg:
cfg["multi_resolution"] = False
for k, v in vars(args).items():
if k in cfg and v is not None:
cfg[k] = v
return cfg
def parse_configs(training=False):
args = parse_args(training)
cfg = Config.fromfile(args.config)
cfg = merge_args(cfg, args, training)
return cfg
def create_experiment_workspace(cfg):
"""
This function creates a folder for experiment tracking.
Args:
args: The parsed arguments.
Returns:
exp_dir: The path to the experiment folder.
"""
# Make outputs folder (holds all experiment subfolders)
os.makedirs(cfg.outputs, exist_ok=True)
experiment_index = len(glob(f"{cfg.outputs}/*"))
# Create an experiment folder
model_name = cfg.model["type"].replace("/", "-")
exp_name = f"{experiment_index:03d}-F{cfg.num_frames}S{cfg.frame_interval}-{model_name}"
exp_dir = f"{cfg.outputs}/{exp_name}"
os.makedirs(exp_dir, exist_ok=True)
return exp_name, exp_dir
def save_training_config(cfg, experiment_dir):
with open(f"{experiment_dir}/config.txt", "w") as f:
json.dump(cfg, f, indent=4)
def create_tensorboard_writer(exp_dir):
tensorboard_dir = f"{exp_dir}/tensorboard"
os.makedirs(tensorboard_dir, exist_ok=True)
writer = SummaryWriter(tensorboard_dir)
return writer
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment