Commit 30af93f2 authored by chenpangpang

feat: initial GPU commit

parent 68e98ab8
import copy
import numpy as np
import torch
from scipy.spatial.transform import Rotation as R
def get_opencv_from_blender(matrix_world, fov, image_size):
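    """
    Convert a Blender camera-to-world matrix (4x4 torch tensor) and a field of view
    (in radians) into OpenCV-convention extrinsics R [1, 3, 3], T [1, 3] and a
    pixel-space intrinsic matrix [1, 3, 3], assuming a square image of side `image_size`.
    """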
# convert matrix_world to opencv format extrinsics
opencv_world_to_cam = matrix_world.inverse()
opencv_world_to_cam[1, :] *= -1
opencv_world_to_cam[2, :] *= -1
R, T = opencv_world_to_cam[:3, :3], opencv_world_to_cam[:3, 3]
R, T = R.unsqueeze(0), T.unsqueeze(0)
# convert fov to opencv format intrinsics
focal = 1 / np.tan(fov / 2)
intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
opencv_cam_matrix = torch.from_numpy(intrinsics).unsqueeze(0).float()
opencv_cam_matrix[:, :2, -1] += torch.tensor([image_size / 2, image_size / 2])
opencv_cam_matrix[:, [0, 1], [0, 1]] *= image_size / 2
return R, T, opencv_cam_matrix
def cartesian_to_spherical(xyz):
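    """
    Convert [N, 3] cartesian points to spherical coordinates [theta, azimuth, radius],
    where theta is the polar angle measured from the +z axis.
    """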
xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
z = np.sqrt(xy + xyz[:, 2] ** 2)
# for elevation angle defined from z-axis down
theta = np.arctan2(np.sqrt(xy), xyz[:, 2])
azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
return np.stack([theta, azimuth, z], axis=-1)
def spherical_to_cartesian(spherical_coords):
# convert from spherical to cartesian coordinates
theta, azimuth, radius = spherical_coords.T
x = radius * np.sin(theta) * np.cos(azimuth)
y = radius * np.sin(theta) * np.sin(azimuth)
z = radius * np.cos(theta)
return np.stack([x, y, z], axis=-1)
def look_at(eye, center, up):
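    """
    Build a world-to-camera rotation R and translation T for a camera at `eye` looking
    towards `center` with the given `up` vector (camera looks down its -z axis).
    """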
# Create a normalized direction vector from eye to center
f = np.array(center) - np.array(eye)
f /= np.linalg.norm(f)
# Create a normalized right vector
up_norm = np.array(up) / np.linalg.norm(up)
s = np.cross(f, up_norm)
s /= np.linalg.norm(s)
# Recompute the up vector
u = np.cross(s, f)
# Create rotation matrix R
R = np.array([[s[0], s[1], s[2]], [u[0], u[1], u[2]], [-f[0], -f[1], -f[2]]])
# Create translation vector T
T = -np.dot(R, np.array(eye))
return R, T
def get_blender_from_spherical(elevation, azimuth):
"""Generates blender camera from spherical coordinates."""
cartesian_coords = spherical_to_cartesian(np.array([[elevation, azimuth, 3.5]]))
# get camera rotation
center = np.array([0, 0, 0])
eye = cartesian_coords[0]
up = np.array([0, 0, 1])
R, T = look_at(eye, center, up)
R = R.T
T = -np.dot(R, T)
RT = np.concatenate([R, T.reshape(3, 1)], axis=-1)
blender_cam = torch.from_numpy(RT).float()
blender_cam = torch.cat([blender_cam, torch.tensor([[0, 0, 0, 1]])], dim=0)
print(blender_cam)
return blender_cam
def invert_pose(r, t):
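    """Invert a rigid pose: (R, t) -> (R^T, -R^T t)."""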
r_inv = r.T
t_inv = -np.dot(r_inv, t)
return r_inv, t_inv
def transform_pose_sequence_to_relative(poses, as_z_up=False):
"""
poses: a sequence of 3*4 C2W camera pose matrices
as_z_up: output in z-up format. If False, the output is in y-up format
"""
r0, t0 = poses[0][:3, :3], poses[0][:3, 3]
# r0_inv, t0_inv = invert_pose(r0, t0)
r0_inv = r0.T
new_rt0 = np.hstack([np.eye(3, 3), np.zeros((3, 1))])
if as_z_up:
new_rt0 = c2w_y_up_to_z_up(new_rt0)
transformed_poses = [new_rt0]
for pose in poses[1:]:
r, t = pose[:3, :3], pose[:3, 3]
new_r = np.dot(r0_inv, r)
new_t = np.dot(r0_inv, t - t0)
new_rt = np.hstack([new_r, new_t[:, None]])
if as_z_up:
new_rt = c2w_y_up_to_z_up(new_rt)
transformed_poses.append(new_rt)
return transformed_poses
def c2w_y_up_to_z_up(c2w_3x4):
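    """Re-express a 3x4 camera-to-world pose from a y-up world frame in a z-up world frame."""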
R_y_up_to_z_up = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
R = c2w_3x4[:, :3]
t = c2w_3x4[:, 3]
R_z_up = R_y_up_to_z_up @ R
t_z_up = R_y_up_to_z_up @ t
T_z_up = np.hstack((R_z_up, t_z_up.reshape(3, 1)))
return T_z_up
def transform_pose_sequence_to_relative_w2c(poses):
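    """
    Express a sequence of 3x4 W2C poses relative to the first frame.
    Note: the input pose matrices are modified in place.
    """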
new_rt_list = []
first_frame_rt = copy.deepcopy(poses[0])
first_frame_r_inv = first_frame_rt[:, :3].T
first_frame_t = first_frame_rt[:, -1]
for rt in poses:
rt[:, :3] = np.matmul(rt[:, :3], first_frame_r_inv)
rt[:, -1] = rt[:, -1] - np.matmul(rt[:, :3], first_frame_t)
new_rt_list.append(copy.deepcopy(rt))
return new_rt_list
def transform_pose_sequence_to_relative_c2w(poses):
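    """
    Express a batch of C2W poses [N, 3, 4] (torch tensors) relative to the first frame,
    so that the first pose becomes identity rotation and zero translation.
    """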
first_frame_rt = poses[0]
first_frame_r_inv = first_frame_rt[:, :3].T
first_frame_t = first_frame_rt[:, -1]
rotations = poses[:, :, :3]
translations = poses[:, :, 3]
# Compute new rotations and translations in batch
new_rotations = torch.matmul(first_frame_r_inv, rotations)
new_translations = torch.matmul(
first_frame_r_inv, (translations - first_frame_t.unsqueeze(0)).unsqueeze(-1)
)
# Concatenate new rotations and translations
new_rt = torch.cat([new_rotations, new_translations], dim=-1)
return new_rt
def convert_w2c_between_c2w(poses):
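    """Invert a batch of [N, 3, 4] poses (W2C <-> C2W): R -> R^T, t -> -R^T t."""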
rotations = poses[:, :, :3]
translations = poses[:, :, 3]
new_rotations = rotations.transpose(-1, -2)
new_translations = torch.matmul(-new_rotations, translations.unsqueeze(-1))
new_rt = torch.cat([new_rotations, new_translations], dim=-1)
return new_rt
def slerp(q1, q2, t):
"""
Performs spherical linear interpolation (SLERP) between two quaternions.
Args:
q1 (torch.Tensor): Start quaternion (4,).
q2 (torch.Tensor): End quaternion (4,).
t (float or torch.Tensor): Interpolation parameter in [0, 1].
Returns:
torch.Tensor: Interpolated quaternion (4,).
"""
q1 = q1 / torch.linalg.norm(q1) # Normalize q1
q2 = q2 / torch.linalg.norm(q2) # Normalize q2
dot = torch.dot(q1, q2)
# Ensure shortest path (flip q2 if needed)
if dot < 0.0:
q2 = -q2
dot = -dot
# Avoid numerical precision issues
dot = torch.clamp(dot, -1.0, 1.0)
theta = torch.acos(dot) # Angle between q1 and q2
if theta < 1e-6: # If very close, use linear interpolation
return (1 - t) * q1 + t * q2
sin_theta = torch.sin(theta)
return (torch.sin((1 - t) * theta) / sin_theta) * q1 + (
torch.sin(t * theta) / sin_theta
) * q2
def interpolate_camera_poses(c2w: torch.Tensor, factor: int) -> torch.Tensor:
"""
Interpolates a sequence of camera c2w poses to N times the length of the original sequence.
Args:
c2w (torch.Tensor): Input camera poses of shape (N, 3, 4).
factor (int): The upsampling factor (e.g., 2 for doubling the length).
Returns:
torch.Tensor: Interpolated camera poses of shape (N * factor, 3, 4).
"""
assert c2w.ndim == 3 and c2w.shape[1:] == (
3,
4,
), "Input tensor must have shape (N, 3, 4)."
assert factor > 1, "Upsampling factor must be greater than 1."
N = c2w.shape[0]
new_length = N * factor
# Extract rotations (R) and translations (T)
rotations = c2w[:, :3, :3] # Shape (N, 3, 3)
translations = c2w[:, :3, 3] # Shape (N, 3)
# Convert rotations to quaternions for interpolation
quaternions = torch.tensor(
R.from_matrix(rotations.numpy()).as_quat()
) # Shape (N, 4)
# Initialize interpolated quaternions and translations
interpolated_quats = []
interpolated_translations = []
# Perform interpolation
for i in range(N - 1):
# Start and end quaternions and translations for this segment
q1, q2 = quaternions[i], quaternions[i + 1]
t1, t2 = translations[i], translations[i + 1]
        # Time steps within this segment; the right endpoint is excluded so that
        # consecutive segments do not duplicate the shared pose
        t_values = torch.arange(factor, dtype=torch.float32) / factor
        # Interpolate quaternions using SLERP
        for t in t_values:
            interpolated_quats.append(slerp(q1, q2, t))
        # Interpolate translations linearly
        interp_t = t1 * (1 - t_values[:, None]) + t2 * t_values[:, None]
        interpolated_translations.append(interp_t)
    # Pad with the last pose so the output length is exactly N * factor
    for _ in range(factor):
        interpolated_quats.append(quaternions[-1])
        interpolated_translations.append(translations[-1].unsqueeze(0))  # Add as 2D tensor
# Combine interpolated results
interpolated_quats = torch.stack(interpolated_quats, dim=0) # Shape (new_length, 4)
interpolated_translations = torch.cat(
interpolated_translations, dim=0
) # Shape (new_length, 3)
# Convert quaternions back to rotation matrices
interpolated_rotations = torch.tensor(
R.from_quat(interpolated_quats.numpy()).as_matrix()
) # Shape (new_length, 3, 3)
# Form final c2w matrix
interpolated_c2w = torch.zeros((new_length, 3, 4), dtype=torch.float32)
interpolated_c2w[:, :3, :3] = interpolated_rotations
interpolated_c2w[:, :3, 3] = interpolated_translations
return interpolated_c2w
import PIL
import numpy as np
import torch
from PIL import Image
from .camera_pose_utils import (
convert_w2c_between_c2w,
transform_pose_sequence_to_relative_c2w,
)
def get_ray_embeddings(
poses, size_h=256, size_w=256, fov_xy_list=None, focal_xy_list=None
):
"""
poses: sequence of cameras poses (y-up format)
"""
use_focal = False
if fov_xy_list is None or fov_xy_list[0] is None or fov_xy_list[0][0] is None:
assert focal_xy_list is not None
use_focal = True
rays_embeddings = []
for i in range(poses.shape[0]):
cur_pose = poses[i]
        if use_focal:
            rays_o, rays_d = get_rays(
                cur_pose,
                size_h,
                size_w,
                focal_xy=focal_xy_list[i],
            )  # [h, w, 3]
        else:
            rays_o, rays_d = get_rays(
                cur_pose, size_h, size_w, fov_xy=fov_xy_list[i]
            )  # [h, w, 3]
rays_plucker = torch.cat(
[torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1
) # [h, w, 6]
rays_embeddings.append(rays_plucker)
rays_embeddings = (
torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous()
) # [V, 6, h, w]
return rays_embeddings
def get_rays(pose, h, w, fov_xy=None, focal_xy=None, opengl=True):
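    """
    Compute per-pixel ray origins and directions in world space for a camera-to-world
    pose, given either a field of view (in degrees) or normalized focal lengths.
    Returns (rays_o, rays_d), each of shape [h, w, 3]; `opengl=True` assumes a y-up,
    -z-forward camera convention.
    """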
x, y = torch.meshgrid(
torch.arange(w, device=pose.device),
torch.arange(h, device=pose.device),
indexing="xy",
)
x = x.flatten()
y = y.flatten()
cx = w * 0.5
cy = h * 0.5
# print("fov_xy=", fov_xy)
# print("focal_xy=", focal_xy)
if focal_xy is None:
assert fov_xy is not None, "fov_x/y and focal_x/y cannot both be None."
focal_x = w * 0.5 / np.tan(0.5 * np.deg2rad(fov_xy[0]))
focal_y = h * 0.5 / np.tan(0.5 * np.deg2rad(fov_xy[1]))
else:
assert (
len(focal_xy) == 2
), "focal_xy should be a list-like object containing only two elements (focal length in x and y direction)."
focal_x = w * focal_xy[0]
focal_y = h * focal_xy[1]
camera_dirs = torch.nn.functional.pad(
torch.stack(
[
(x - cx + 0.5) / focal_x,
(y - cy + 0.5) / focal_y * (-1.0 if opengl else 1.0),
],
dim=-1,
),
(0, 1),
value=(-1.0 if opengl else 1.0),
) # [hw, 3]
rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3]
rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]
rays_o = rays_o.view(h, w, 3)
rays_d = safe_normalize(rays_d).view(h, w, 3)
return rays_o, rays_d
def safe_normalize(x, eps=1e-20):
return x / length(x, eps)
def length(x, eps=1e-20):
if isinstance(x, np.ndarray):
return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))
else:
return torch.sqrt(torch.clamp(dot(x, x), min=eps))
def dot(x, y):
if isinstance(x, np.ndarray):
return np.sum(x * y, -1, keepdims=True)
else:
return torch.sum(x * y, -1, keepdim=True)
def extend_list_by_repeating(original_list, target_length, repeat_idx, at_front):
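    """
    Pad `original_list` to `target_length` by repeating the element at `repeat_idx`,
    placing the padding at the front if `at_front` is True, otherwise at the end.
    """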
if not original_list:
raise ValueError("The original list cannot be empty.")
extended_list = []
original_length = len(original_list)
for i in range(target_length - original_length):
extended_list.append(original_list[repeat_idx])
if at_front:
extended_list.extend(original_list)
return extended_list
else:
original_list.extend(extended_list)
return original_list
def select_evenly_spaced_elements(arr, x):
if x <= 0 or len(arr) == 0:
return []
# Calculate step size as the ratio of length of the list and x
step = len(arr) / x
# Pick elements at indices that are multiples of step (round them to nearest integer)
selected_elements = [arr[round(i * step)] for i in range(x)]
return selected_elements
def convert_co3d_annotation_to_opengl_pose_and_intrinsics(frame_annotation):
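    """
    Convert a CO3D frame annotation (PyTorch3D/NDC viewpoint conventions) into a 3x4
    camera pose in OpenGL convention (x and z axes flipped) and a 3x3 pixel-space
    intrinsic matrix K.
    """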
p = frame_annotation.viewpoint.principal_point
f = frame_annotation.viewpoint.focal_length
h, w = frame_annotation.image.size
K = np.eye(3)
s = (min(h, w) - 1) / 2
if frame_annotation.viewpoint.intrinsics_format == "ndc_norm_image_bounds":
K[0, 0] = f[0] * (w - 1) / 2
K[1, 1] = f[1] * (h - 1) / 2
elif frame_annotation.viewpoint.intrinsics_format == "ndc_isotropic":
K[0, 0] = f[0] * s / 2
K[1, 1] = f[1] * s / 2
else:
assert (
False
), f"Invalid intrinsics_format: {frame_annotation.viewpoint.intrinsics_format}"
K[0, 2] = -p[0] * s + (w - 1) / 2
K[1, 2] = -p[1] * s + (h - 1) / 2
R = np.array(frame_annotation.viewpoint.R).T # note the transpose here
T = np.array(frame_annotation.viewpoint.T)
pose = np.concatenate([R, T[:, None]], 1)
# Need to be converted into OpenGL format. Flip the direction of x, z axis
pose = np.diag([-1, 1, -1]).astype(np.float32) @ pose
return pose, K
def normalize_w2c_camera_pose_sequence(
target_camera_poses,
condition_camera_poses=None,
output_c2w=False,
translation_norm_mode="div_by_max",
):
"""
Normalize camera pose sequence so that the first frame is identity rotation and zero translation,
and the translation scale is normalized by the farest point from the first frame (to one).
:param target_camera_poses: W2C poses tensor in [N, 3, 4]
:param condition_camera_poses: W2C poses tensor in [N, 3, 4]
:return: Tuple(Tensor, Tensor), the normalized `target_camera_poses` and `condition_camera_poses`
"""
# Normalize at w2c, all poses should be in w2c in UnifiedFrame
num_target_views = target_camera_poses.size(0)
if condition_camera_poses is not None:
all_poses = torch.concat([target_camera_poses, condition_camera_poses], dim=0)
else:
all_poses = target_camera_poses
# Convert W2C to C2W
normalized_poses = transform_pose_sequence_to_relative_c2w(
convert_w2c_between_c2w(all_poses)
)
# Here normalized_poses is C2W
if not output_c2w:
# Convert from C2W back to W2C if output_c2w is False.
normalized_poses = convert_w2c_between_c2w(normalized_poses)
t_norms = torch.linalg.norm(normalized_poses[:, :, 3], ord=2, dim=-1)
# print("t_norms=", t_norms)
largest_t_norm = torch.max(t_norms)
# print("largest_t_norm=", largest_t_norm)
# normalized_poses[:, :, 3] -= first_t.unsqueeze(0).repeat(normalized_poses.size(0), 1)
if translation_norm_mode == "div_by_max_plus_one":
# Always add a constant component to the translation norm
largest_t_norm = largest_t_norm + 1.0
elif translation_norm_mode == "div_by_max":
largest_t_norm = largest_t_norm
if largest_t_norm <= 0.05:
largest_t_norm = 0.05
elif translation_norm_mode == "disabled":
largest_t_norm = 1.0
else:
assert False, f"Invalid translation_norm_mode: {translation_norm_mode}."
normalized_poses[:, :, 3] /= largest_t_norm
target_camera_poses = normalized_poses[:num_target_views]
if condition_camera_poses is not None:
condition_camera_poses = normalized_poses[num_target_views:]
else:
condition_camera_poses = None
# print("After First condition:", condition_camera_poses[0])
# print("After First target:", target_camera_poses[0])
return target_camera_poses, condition_camera_poses
def central_crop_pil_image(_image, crop_size, use_central_padding=False):
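    """
    Resize the shorter side of a PIL image to `crop_size` and take a central crop of
    size `crop_size` x `crop_size`; optionally pad to a square with white first.
    """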
if use_central_padding:
# Determine the new size
_w, _h = _image.size
new_size = max(_w, _h)
# Create a new image with white background
new_image = Image.new("RGB", (new_size, new_size), (255, 255, 255))
# Calculate the position to paste the original image
paste_position = ((new_size - _w) // 2, (new_size - _h) // 2)
# Paste the original image onto the new image
new_image.paste(_image, paste_position)
_image = new_image
# get the new size again if padded
_w, _h = _image.size
scale = crop_size / min(_h, _w)
# resize shortest side to crop_size
_w_out, _h_out = int(scale * _w), int(scale * _h)
_image = _image.resize(
(_w_out, _h_out),
resample=(
PIL.Image.Resampling.LANCZOS if scale < 1 else PIL.Image.Resampling.BICUBIC
),
)
# center crop
margin_w = (_image.size[0] - crop_size) // 2
margin_h = (_image.size[1] - crop_size) // 2
_image = _image.crop(
(margin_w, margin_h, margin_w + crop_size, margin_h + crop_size)
)
return _image
def crop_and_resize(
image: Image.Image, target_width: int, target_height: int
) -> Image.Image:
"""
Crops and resizes an image while preserving the aspect ratio.
Args:
image (Image.Image): Input PIL image to be cropped and resized.
target_width (int): Target width of the output image.
target_height (int): Target height of the output image.
Returns:
Image.Image: Cropped and resized image.
"""
# Original dimensions
original_width, original_height = image.size
original_aspect = original_width / original_height
target_aspect = target_width / target_height
# Calculate crop box to maintain aspect ratio
if original_aspect > target_aspect:
# Crop horizontally
new_width = int(original_height * target_aspect)
new_height = original_height
left = (original_width - new_width) / 2
top = 0
right = left + new_width
bottom = original_height
else:
# Crop vertically
new_width = original_width
new_height = int(original_width / target_aspect)
left = 0
top = (original_height - new_height) / 2
right = original_width
bottom = top + new_height
# Crop and resize
cropped_image = image.crop((left, top, right, bottom))
resized_image = cropped_image.resize((target_width, target_height), Image.LANCZOS)
return resized_image
def calculate_fov_after_resize(
fov_x: float,
fov_y: float,
original_width: int,
original_height: int,
target_width: int,
target_height: int,
) -> (float, float):
"""
Calculates the new field of view after cropping and resizing an image.
Args:
fov_x (float): Original field of view in the x-direction (horizontal).
fov_y (float): Original field of view in the y-direction (vertical).
original_width (int): Original width of the image.
original_height (int): Original height of the image.
target_width (int): Target width of the output image.
target_height (int): Target height of the output image.
Returns:
(float, float): New field of view (fov_x, fov_y) after cropping and resizing.
"""
original_aspect = original_width / original_height
target_aspect = target_width / target_height
if original_aspect > target_aspect:
# Crop horizontally
new_width = int(original_height * target_aspect)
new_fov_x = fov_x * (new_width / original_width)
new_fov_y = fov_y
else:
# Crop vertically
new_height = int(original_width / target_aspect)
new_fov_y = fov_y * (new_height / original_height)
new_fov_x = fov_x
return new_fov_x, new_fov_y
import copy
import random
from PIL import Image
import numpy as np
def create_relative(RT_list, K_1=4.7, dataset="syn"):
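    """
    Make a list of 3x4 camera [R|T] matrices relative to the first frame and scale the
    translations with a dataset-dependent factor. Input matrices are modified in place.
    """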
if dataset == "realestate":
scale_T = 1
RT_list = [RT.reshape(3, 4) for RT in RT_list]
elif dataset == "syn":
scale_T = (470 / K_1) / 7.5
"""
4.694746736956946052e+02 0.000000000000000000e+00 4.800000000000000000e+02
0.000000000000000000e+00 4.694746736956946052e+02 2.700000000000000000e+02
0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00
"""
elif dataset == "zero123":
scale_T = 0.5
else:
raise Exception("invalid dataset type")
# convert x y z to x -y -z
if dataset == "zero123":
flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
for i in range(len(RT_list)):
RT_list[i] = np.dot(flip_matrix, RT_list[i])
temp = []
first_frame_RT = copy.deepcopy(RT_list[0])
# first_frame_R_inv = np.linalg.inv(first_frame_RT[:,:3])
first_frame_R_inv = first_frame_RT[:, :3].T
first_frame_T = first_frame_RT[:, -1]
for RT in RT_list:
RT[:, :3] = np.dot(RT[:, :3], first_frame_R_inv)
RT[:, -1] = RT[:, -1] - np.dot(RT[:, :3], first_frame_T)
RT[:, -1] = RT[:, -1] * scale_T
temp.append(RT)
RT_list = temp
if dataset == "realestate":
RT_list = [RT.reshape(-1) for RT in RT_list]
return RT_list
def sigma_matrix2(sig_x, sig_y, theta):
"""Calculate the rotated sigma matrix (two dimensional matrix).
Args:
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
Returns:
ndarray: Rotated sigma matrix.
"""
d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
u_matrix = np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
def mesh_grid(kernel_size):
"""Generate the mesh grid, centering at zero.
Args:
kernel_size (int):
Returns:
xy (ndarray): with the shape (kernel_size, kernel_size, 2)
xx (ndarray): with the shape (kernel_size, kernel_size)
yy (ndarray): with the shape (kernel_size, kernel_size)
"""
ax = np.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0)
xx, yy = np.meshgrid(ax, ax)
xy = np.hstack(
(
xx.reshape((kernel_size * kernel_size, 1)),
yy.reshape(kernel_size * kernel_size, 1),
)
).reshape(kernel_size, kernel_size, 2)
return xy, xx, yy
def pdf2(sigma_matrix, grid):
"""Calculate PDF of the bivariate Gaussian distribution.
Args:
sigma_matrix (ndarray): with the shape (2, 2)
grid (ndarray): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size.
Returns:
        kernel (ndarray): un-normalized kernel.
"""
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
return kernel
def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
"""Generate a bivariate isotropic or anisotropic Gaussian kernel.
    In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
isotropic (bool):
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
if isotropic:
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
else:
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
kernel = pdf2(sigma_matrix, grid)
kernel = kernel / np.sum(kernel)
return kernel
def rgba_to_rgb_with_bg(rgba_image, bg_color=(255, 255, 255)):
"""
Convert a PIL RGBA Image to an RGB Image with a white background.
Args:
rgba_image (Image): A PIL Image object in RGBA mode.
Returns:
Image: A PIL Image object in RGB mode with white background.
"""
    # Ensure the image is in RGBA mode
if rgba_image.mode != "RGBA":
return rgba_image
# raise ValueError("The image must be in RGBA mode")
# Create a white background image
white_bg_rgb = Image.new("RGB", rgba_image.size, bg_color)
# Paste the RGBA image onto the white background using alpha channel as mask
white_bg_rgb.paste(
rgba_image, mask=rgba_image.split()[3]
) # 3 is the alpha channel index
return white_bg_rgb
def random_order_preserving_selection(items, num):
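    """Randomly select `num` items from `items` while preserving their original order."""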
if num > len(items):
print("WARNING: Item list is shorter than `num` given.")
return items
selected_indices = sorted(random.sample(range(len(items)), num))
selected_items = [items[i] for i in selected_indices]
return selected_items
def pad_pil_image_to_square(image, fill_color=(255, 255, 255)):
"""
Pad an image to make it square with the given fill color.
Args:
image (PIL.Image): The original image.
        fill_color (tuple): The color to use for padding (default is white).
Returns:
PIL.Image: A new image that is padded to be square.
"""
width, height = image.size
# Determine the new size, which will be the maximum of width or height
new_size = max(width, height)
# Create a new image with the new size and fill color
new_image = Image.new("RGB", (new_size, new_size), fill_color)
# Calculate the position to paste the original image onto the new image
# This calculation centers the original image in the new square canvas
left = (new_size - width) // 2
top = (new_size - height) // 2
# Paste the original image into the new image
new_image.paste(image, (left, top))
return new_image
import torch
import numpy as np
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
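    """Diagonal Gaussian distribution parameterized by concatenated mean and log-variance (split along dim=1)."""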
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(
device=self.parameters.device
)
def sample(self, noise=None):
if noise is None:
noise = torch.randn(self.mean.shape)
x = self.mean + self.std * noise.to(device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
import torch
from torch import nn
class LitEma(nn.Module):
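    """Exponential moving average (EMA) of a model's trainable parameters, stored as buffers."""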
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError("Decay must be between 0 and 1")
self.m_name2s_name = {}
self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
"num_updates",
(
torch.tensor(0, dtype=torch.int)
if use_num_upates
else torch.tensor(-1, dtype=torch.int)
),
)
for name, p in model.named_parameters():
if p.requires_grad:
# remove as '.'-character is not allowed in buffers
s_name = name.replace(".", "")
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(
one_minus_decay * (shadow_params[sname] - m_param[key])
)
else:
assert not key in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
from core.losses.contperceptual import LPIPSWithDiscriminator
import torch
import torch.nn as nn
from einops import rearrange
from taming.modules.losses.vqperceptual import *
class LPIPSWithDiscriminator(nn.Module):
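    """
    First-stage (VAE) training loss: L1 + LPIPS reconstruction, KL regularization, and an
    adversarial PatchGAN term whose weight is chosen adaptively from gradient norms.
    """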
def __init__(
self,
disc_start,
logvar_init=0.0,
kl_weight=1.0,
pixelloss_weight=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
perceptual_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_loss="hinge",
max_bs=None,
):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
self.kl_weight = kl_weight
self.pixel_weight = pixelloss_weight
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
# output log variance
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
).apply(weights_init)
self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
self.max_bs = max_bs
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
nll_grads = torch.autograd.grad(
nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(
self,
inputs,
reconstructions,
posteriors,
optimizer_idx,
global_step,
last_layer=None,
cond=None,
split="train",
weights=None,
):
if inputs.dim() == 5:
inputs = rearrange(inputs, "b c t h w -> (b t) c h w")
if reconstructions.dim() == 5:
reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w")
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
if self.perceptual_weight > 0:
if self.max_bs is not None and self.max_bs < inputs.shape[0]:
input_list = torch.split(inputs, self.max_bs, dim=0)
reconstruction_list = torch.split(reconstructions, self.max_bs, dim=0)
p_losses = [
self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
for inputs, reconstructions in zip(input_list, reconstruction_list)
]
p_loss = torch.cat(p_losses, dim=0)
else:
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights * nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
kl_loss = posteriors.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
# now the GAN part
if optimizer_idx == 0:
# generator update
if cond is None:
assert not self.disc_conditional
logits_fake = self.discriminator(reconstructions.contiguous())
else:
assert self.disc_conditional
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous(), cond), dim=1)
)
g_loss = -torch.mean(logits_fake)
if self.disc_factor > 0.0:
try:
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
else:
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
loss = (
weighted_nll_loss
+ self.kl_weight * kl_loss
+ d_weight * disc_factor * g_loss
)
log = {
"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/logvar".format(split): self.logvar.detach(),
"{}/kl_loss".format(split): kl_loss.detach().mean(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
else:
logits_real = self.discriminator(
torch.cat((inputs.contiguous().detach(), cond), dim=1)
)
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
)
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {
"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean(),
}
return d_loss, log
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3])
loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3])
loss_real = (weights * loss_real).sum() / weights.sum()
loss_fake = (weights * loss_fake).sum() / weights.sum()
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def adopt_weight(weight, global_step, threshold=0, value=0.0):
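    """Return `value` until `global_step` reaches `threshold`, then return `weight` (delays the GAN loss)."""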
if global_step < threshold:
weight = value
return weight
def measure_perplexity(predicted_indices, n_embed):
encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use
def exists(x):
    # small helper used below to check the optional codebook_loss
    return x is not None
def l1(x, y):
    return torch.abs(x - y)
def l2(x, y):
    return torch.pow((x - y), 2)
class VQLPIPSWithDiscriminator(nn.Module):
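    """
    VQ-VAE training loss: pixel + LPIPS reconstruction, codebook commitment term, and an
    adversarial PatchGAN term whose weight is chosen adaptively from gradient norms.
    """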
def __init__(
self,
disc_start,
codebook_weight=1.0,
pixelloss_weight=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
perceptual_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_ndf=64,
disc_loss="hinge",
n_classes=None,
perceptual_loss="lpips",
pixel_loss="l1",
):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
assert perceptual_loss in ["lpips", "clips", "dists"]
assert pixel_loss in ["l1", "l2"]
self.codebook_weight = codebook_weight
self.pixel_weight = pixelloss_weight
if perceptual_loss == "lpips":
print(f"{self.__class__.__name__}: Running with LPIPS.")
self.perceptual_loss = LPIPS().eval()
else:
raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
self.perceptual_weight = perceptual_weight
if pixel_loss == "l1":
self.pixel_loss = l1
else:
self.pixel_loss = l2
self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm,
ndf=disc_ndf,
).apply(weights_init)
self.discriminator_iter_start = disc_start
if disc_loss == "hinge":
self.disc_loss = hinge_d_loss
elif disc_loss == "vanilla":
self.disc_loss = vanilla_d_loss
else:
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
self.n_classes = n_classes
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
nll_grads = torch.autograd.grad(
nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(
self,
codebook_loss,
inputs,
reconstructions,
optimizer_idx,
global_step,
last_layer=None,
cond=None,
split="train",
predicted_indices=None,
):
if not exists(codebook_loss):
codebook_loss = torch.tensor([0.0]).to(inputs.device)
# rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
else:
p_loss = torch.tensor([0.0])
nll_loss = rec_loss
# nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
nll_loss = torch.mean(nll_loss)
# now the GAN part
if optimizer_idx == 0:
# generator update
if cond is None:
assert not self.disc_conditional
logits_fake = self.discriminator(reconstructions.contiguous())
else:
assert self.disc_conditional
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous(), cond), dim=1)
)
g_loss = -torch.mean(logits_fake)
try:
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
loss = (
nll_loss
+ d_weight * disc_factor * g_loss
+ self.codebook_weight * codebook_loss.mean()
)
log = {
"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/p_loss".format(split): p_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
if predicted_indices is not None:
assert self.n_classes is not None
with torch.no_grad():
perplexity, cluster_usage = measure_perplexity(
predicted_indices, self.n_classes
)
log[f"{split}/perplexity"] = perplexity
log[f"{split}/cluster_usage"] = cluster_usage
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
else:
logits_real = self.discriminator(
torch.cat((inputs.contiguous().detach(), cond), dim=1)
)
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
)
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {
"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean(),
}
return d_loss, log
import os
import json
from contextlib import contextmanager
import torch
import numpy as np
from einops import rearrange
import torch.nn.functional as F
import torch.distributed as dist
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from taming.modules.vqvae.quantize import VectorQuantizer as VectorQuantizer
from core.modules.networks.ae_modules import Encoder, Decoder
from core.distributions import DiagonalGaussianDistribution
from utils.utils import instantiate_from_config
from utils.save_video import tensor2videogrids
from core.common import shape_to_str, gather_data
class AutoencoderKL(pl.LightningModule):
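    """
    KL-regularized image autoencoder (VAE) LightningModule trained with a GAN/perceptual
    loss; 5D video batches are supported by folding the time dimension into the batch.
    """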
def __init__(
self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
test=False,
logdir=None,
input_dim=4,
test_args=None,
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
self.input_dim = input_dim
self.test = test
self.test_args = test_args
self.logdir = logdir
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
if self.test:
self.init_test()
def init_test(
self,
):
self.test = True
save_dir = os.path.join(self.logdir, "test")
if "ckpt" in self.test_args:
ckpt_name = (
os.path.basename(self.test_args.ckpt).split(".ckpt")[0]
+ f"_epoch{self._cur_epoch}"
)
self.root = os.path.join(save_dir, ckpt_name)
else:
self.root = save_dir
if "test_subdir" in self.test_args:
self.root = os.path.join(save_dir, self.test_args.test_subdir)
self.root_zs = os.path.join(self.root, "zs")
self.root_dec = os.path.join(self.root, "reconstructions")
self.root_inputs = os.path.join(self.root, "inputs")
os.makedirs(self.root, exist_ok=True)
if self.test_args.save_z:
os.makedirs(self.root_zs, exist_ok=True)
if self.test_args.save_reconstruction:
os.makedirs(self.root_dec, exist_ok=True)
if self.test_args.save_input:
os.makedirs(self.root_inputs, exist_ok=True)
assert self.test_args is not None
self.test_maximum = getattr(
self.test_args, "test_maximum", None
) # 1500 # 12000/8
self.count = 0
self.eval_metrics = {}
self.decodes = []
self.save_decode_samples = 2048
if getattr(self.test_args, "cal_metrics", False):
self.EvalLpips = EvalLpips()
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")
try:
self._cur_epoch = sd["epoch"]
sd = sd["state_dict"]
except:
self._cur_epoch = "null"
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
# self.load_state_dict(sd, strict=True)
print(f"Restored from {path}")
def encode(self, x, **kwargs):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z, **kwargs):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
# if len(x.shape) == 3:
# x = x[..., None]
# if x.dim() == 4:
# x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
if x.dim() == 5 and self.input_dim == 4:
b, c, t, h, w = x.shape
self.b = b
self.t = t
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log(
"aeloss",
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log(
"discloss",
discloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
)
return discloss
def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
0,
self.global_step,
last_layer=self.get_last_layer(),
split="val",
)
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
1,
self.global_step,
last_layer=self.get_last_layer(),
split="val",
)
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def test_step(self, batch, batch_idx):
# save z, dec
inputs = self.get_input(batch, self.image_key)
# forward
sample_posterior = True
posterior = self.encode(inputs)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
# logs
if self.test_args.save_z:
torch.save(
z,
os.path.join(
self.root_zs,
f"zs_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.pt",
),
)
if self.test_args.save_reconstruction:
tensor2videogrids(
dec,
self.root_dec,
f"reconstructions_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.mp4",
fps=10,
)
if self.test_args.save_input:
tensor2videogrids(
inputs,
self.root_inputs,
f"inputs_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.mp4",
fps=10,
)
if "save_z" in self.test_args and self.test_args.save_z:
dec_np = (dec.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) + 1) / 2 * 255
dec_np = dec_np.astype(np.uint8)
self.root_dec_np = os.path.join(self.root, "reconstructions_np")
os.makedirs(self.root_dec_np, exist_ok=True)
np.savez(
os.path.join(
self.root_dec_np,
f"reconstructions_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(dec_np)}.npz",
),
dec_np,
)
self.count += z.shape[0]
# misc
self.log("batch_idx", batch_idx, prog_bar=True)
self.log_dict(self.eval_metrics, prog_bar=True, logger=True)
torch.cuda.empty_cache()
if self.test_maximum is not None:
if self.count > self.test_maximum:
import sys
sys.exit()
else:
prog = self.count / self.test_maximum * 100
print(f"Test progress: {prog:.2f}% [{self.count}/{self.test_maximum}]")
@rank_zero_only
def on_test_end(self):
if self.test_args.cal_metrics:
psnrs, ssims, ms_ssims, lpipses = [], [], [], []
n_batches = 0
n_samples = 0
overall = {}
for k, v in self.eval_metrics.items():
psnrs.append(v["psnr"])
ssims.append(v["ssim"])
lpipses.append(v["lpips"])
n_batches += 1
n_samples += v["n_samples"]
mean_psnr = sum(psnrs) / len(psnrs)
mean_ssim = sum(ssims) / len(ssims)
# overall['ms_ssim'] = min(ms_ssims)
mean_lpips = sum(lpipses) / len(lpipses)
overall = {
"psnr": mean_psnr,
"ssim": mean_ssim,
"lpips": mean_lpips,
"n_batches": n_batches,
"n_samples": n_samples,
}
overall_t = torch.tensor([mean_psnr, mean_ssim, mean_lpips])
# dump
for k, v in overall.items():
if isinstance(v, torch.Tensor):
overall[k] = float(v)
with open(
os.path.join(self.root, f"reconstruction_metrics.json"), "w"
) as f:
json.dump(overall, f)
f.close()
def configure_optimizers(self):
lr = self.learning_rate
opt_ae = torch.optim.Adam(
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr,
betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
)
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
log["inputs"] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x
class IdentityFirstStage(torch.nn.Module):
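    """A pass-through first stage that returns its input unchanged (optionally mimicking the VQ interface)."""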
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x
import logging
from collections import OrderedDict
from contextlib import contextmanager
from functools import partial
import numpy as np
from einops import rearrange
from tqdm import tqdm
import torch
import torch.nn as nn
from torchvision.utils import make_grid
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from core.modules.networks.unet_modules import TASK_IDX_IMAGE, TASK_IDX_RAY
from utils.utils import instantiate_from_config
from core.ema import LitEma
from core.distributions import DiagonalGaussianDistribution
from core.models.utils_diffusion import make_beta_schedule, rescale_zero_terminal_snr
from core.models.samplers.ddim import DDIMSampler
from core.basics import disabled_train
from core.common import extract_into_tensor, noise_like, exists, default
main_logger = logging.getLogger("main_logger")
class BD(nn.Module):
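    """
    Whitens the input across its G grouped channels using the inverse Cholesky factor of
    their (trace-normalized) covariance; an EMA of the whitening matrix is used at eval time.
    """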
def __init__(self, G=10):
super(BD, self).__init__()
        self.momentum = 0.9
        # running estimate of the whitening matrix, updated with an EMA during training
        self.register_buffer("running_wm", torch.eye(G).expand(G, G))
def forward(self, x, T=5, eps=1e-5):
N, C, G, H, W = x.size()
x = torch.permute(x, [0, 2, 1, 3, 4])
x_in = x.transpose(0, 1).contiguous().view(G, -1)
if self.training:
mean = x_in.mean(-1, keepdim=True)
xc = x_in - mean
d, m = x_in.size()
P = [None] * (T + 1)
P[0] = torch.eye(G, device=x.device)
Sigma = (torch.matmul(xc, xc.transpose(0, 1))) / float(m) + P[0] * eps
rTr = (Sigma * P[0]).sum([0, 1], keepdim=True).reciprocal()
Sigma_N = Sigma * rTr
wm = torch.linalg.solve_triangular(
torch.linalg.cholesky(Sigma_N), P[0], upper=False
)
self.running_wm = self.momentum * self.running_wm + (1 - self.momentum) * wm
else:
wm = self.running_wm
x_out = wm @ x_in
x_out = x_out.view(G, N, C, H, W).permute([1, 2, 0, 3, 4]).contiguous()
return x_out
class AbstractDDPM(pl.LightningModule):
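    """
    Base LightningModule for denoising diffusion models: wraps the UNet, keeps EMA
    weights, and provides the beta schedule and q(x_t | x_0) / posterior utilities.
    """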
def __init__(
self,
unet_config,
time_steps=1000,
beta_schedule="linear",
loss_type="l2",
monitor=None,
use_ema=True,
first_stage_key="image",
image_size=256,
channels=3,
log_every_t=100,
clip_denoised=True,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
given_betas=None,
original_elbo_weight=0.0,
# weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
v_posterior=0.0,
l_simple_weight=1.0,
conditioning_key=None,
parameterization="eps",
rescale_betas_zero_snr=False,
scheduler_config=None,
use_positional_encodings=False,
learn_logvar=False,
logvar_init=0.0,
bd_noise=False,
):
super().__init__()
assert parameterization in [
"eps",
"x0",
"v",
], 'currently only supporting "eps" and "x0" and "v"'
self.parameterization = parameterization
main_logger.info(
f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
)
self.cond_stage_model = None
self.clip_denoised = clip_denoised
self.log_every_t = log_every_t
self.first_stage_key = first_stage_key
self.channels = channels
self.cond_channels = unet_config.params.in_channels - channels
self.temporal_length = unet_config.params.temporal_length
self.image_size = image_size
self.bd_noise = bd_noise
if self.bd_noise:
self.bd = BD(G=self.temporal_length)
if isinstance(self.image_size, int):
self.image_size = [self.image_size, self.image_size]
self.use_positional_encodings = use_positional_encodings
self.model = DiffusionWrapper(unet_config)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model)
main_logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
self.rescale_betas_zero_snr = rescale_betas_zero_snr
self.use_scheduler = scheduler_config is not None
if self.use_scheduler:
self.scheduler_config = scheduler_config
self.v_posterior = v_posterior
self.original_elbo_weight = original_elbo_weight
self.l_simple_weight = l_simple_weight
self.linear_end = None
self.linear_start = None
self.num_time_steps: int = 1000
if monitor is not None:
self.monitor = monitor
self.register_schedule(
given_betas=given_betas,
beta_schedule=beta_schedule,
time_steps=time_steps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s,
)
self.given_betas = given_betas
self.beta_schedule = beta_schedule
self.time_steps = time_steps
self.cosine_s = cosine_s
self.loss_type = loss_type
self.learn_logvar = learn_logvar
self.logvar = torch.full(fill_value=logvar_init, size=(self.num_time_steps,))
if self.learn_logvar:
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
def predict_start_from_noise(self, x_t, t, noise):
return (
extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
* noise
)
def predict_start_from_z_and_v(self, x_t, t, v):
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
)
def predict_eps_from_z_and_v(self, x_t, t, v):
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
* x_t
)
def get_v(self, x, noise, t):
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
)
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.model.parameters())
self.model_ema.copy_to(self.model)
if context is not None:
main_logger.info(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.model.parameters())
if context is not None:
main_logger.info(f"{context}: Restored training weights")
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = extract_into_tensor(
self.log_one_minus_alphas_cumprod, t, x_start.shape
)
return mean, variance, log_variance
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract_into_tensor(
self.posterior_log_variance_clipped, t, x_t.shape
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)
def get_loss(self, pred, target, mean=True):
if self.loss_type == "l1":
loss = (target - pred).abs()
if mean:
loss = loss.mean()
elif self.loss_type == "l2":
if mean:
loss = torch.nn.functional.mse_loss(target, pred)
else:
loss = torch.nn.functional.mse_loss(target, pred, reduction="none")
        else:
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
return loss
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self.model)
def _get_rows_from_list(self, samples):
n_imgs_per_row = len(samples)
denoise_grid = rearrange(samples, "n b c h w -> b n c h w")
denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
return denoise_grid
class DualStreamMultiViewDiffusionModel(AbstractDDPM):
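    """
    Dual-stream latent diffusion model that denoises multi-view images together with
    their camera-ray representations, conditioned on reference images and optional text.
    """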
def __init__(
self,
first_stage_config,
data_key_images,
data_key_rays,
data_key_text_condition=None,
ckpt_path=None,
cond_stage_config=None,
num_time_steps_cond=None,
cond_stage_trainable=False,
cond_stage_forward=None,
conditioning_key=None,
uncond_prob=0.2,
uncond_type="empty_seq",
scale_factor=1.0,
scale_by_std=False,
use_noise_offset=False,
use_dynamic_rescale=False,
base_scale=0.3,
turning_step=400,
per_frame_auto_encoding=False,
# added for LVDM
encoder_type="2d",
cond_frames=None,
logdir=None,
empty_params_only=False,
# Image Condition
cond_img_config=None,
image_proj_model_config=None,
random_cond=False,
padding=False,
cond_concat=False,
frame_mask=False,
use_camera_pose_query_transformer=False,
with_cond_binary_mask=False,
apply_condition_mask_in_training_loss=True,
separate_noise_and_condition=False,
condition_padding_with_anchor=False,
ray_as_image=False,
use_task_embedding=False,
use_ray_decoder_loss_high_frequency_isolation=False,
disable_ray_stream=False,
ray_loss_weight=1.0,
train_with_multi_view_feature_alignment=False,
use_text_cross_attention_condition=True,
*args,
**kwargs,
):
self.image_proj_model = None
self.apply_condition_mask_in_training_loss = (
apply_condition_mask_in_training_loss
)
self.separate_noise_and_condition = separate_noise_and_condition
self.condition_padding_with_anchor = condition_padding_with_anchor
self.use_text_cross_attention_condition = use_text_cross_attention_condition
self.data_key_images = data_key_images
self.data_key_rays = data_key_rays
self.data_key_text_condition = data_key_text_condition
self.num_time_steps_cond = default(num_time_steps_cond, 1)
self.scale_by_std = scale_by_std
assert self.num_time_steps_cond <= kwargs["time_steps"]
self.shorten_cond_schedule = self.num_time_steps_cond > 1
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
self.cond_stage_trainable = cond_stage_trainable
self.empty_params_only = empty_params_only
self.per_frame_auto_encoding = per_frame_auto_encoding
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except Exception:
# fall back when the first-stage config does not expose ddconfig.ch_mult
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
else:
self.register_buffer("scale_factor", torch.tensor(scale_factor))
self.use_noise_offset = use_noise_offset
self.use_dynamic_rescale = use_dynamic_rescale
self.base_scale = base_scale  # kept so that get_dynamic_scales() can reference it
if use_dynamic_rescale:
scale_arr1 = np.linspace(1.0, base_scale, turning_step)
scale_arr2 = np.full(self.num_time_steps, base_scale)
scale_arr = np.concatenate((scale_arr1, scale_arr2))
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer("scale_arr", to_torch(scale_arr))
self.instantiate_first_stage(first_stage_config)
if self.use_text_cross_attention_condition and cond_stage_config is not None:
self.instantiate_cond_stage(cond_stage_config)
self.first_stage_config = first_stage_config
self.cond_stage_config = cond_stage_config
self.clip_denoised = False
self.cond_stage_forward = cond_stage_forward
self.encoder_type = encoder_type
assert encoder_type in ["2d", "3d"]
self.uncond_prob = uncond_prob
self.classifier_free_guidance = uncond_prob > 0
assert uncond_type in ["zero_embed", "empty_seq"]
self.uncond_type = uncond_type
if cond_frames is not None:
frame_len = self.temporal_length
assert cond_frames[-1] < frame_len, main_logger.info(
f"Error: conditioning frame index must not be greater than {frame_len}!"
)
cond_mask = torch.zeros(frame_len, dtype=torch.float32)
cond_mask[cond_frames] = 1.0
self.cond_mask = cond_mask[None, None, :, None, None]
else:
self.cond_mask = None
self.restarted_from_ckpt = False
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path)
self.restarted_from_ckpt = True
self.logdir = logdir
self.with_cond_binary_mask = with_cond_binary_mask
self.random_cond = random_cond
self.padding = padding
self.cond_concat = cond_concat
self.frame_mask = frame_mask
self.use_img_context = True if cond_img_config is not None else False
self.use_camera_pose_query_transformer = use_camera_pose_query_transformer
if self.use_img_context:
self.init_img_embedder(cond_img_config, freeze=True)
self.init_projector(image_proj_model_config, trainable=True)
self.ray_as_image = ray_as_image
self.use_task_embedding = use_task_embedding
self.use_ray_decoder_loss_high_frequency_isolation = (
use_ray_decoder_loss_high_frequency_isolation
)
self.disable_ray_stream = disable_ray_stream
if disable_ray_stream:
assert (
not ray_as_image
and not self.model.diffusion_model.use_ray_decoder
and not self.model.diffusion_model.use_ray_decoder_residual
), "Options related to ray decoder should not be enabled when disabling ray stream."
assert (
not use_task_embedding
and not self.model.diffusion_model.use_task_embedding
), "Task embedding should not be enabled when disabling ray stream."
assert (
not self.model.diffusion_model.use_addition_ray_output_head
), "Additional ray output head should not be enabled when disabling ray stream."
assert (
not self.model.diffusion_model.use_lora_for_rays_in_output_blocks
), "LoRA for rays should not be enabled when disabling ray stream."
self.ray_loss_weight = ray_loss_weight
# Multi-view feature alignment is only used during training; it is disabled here.
self.train_with_multi_view_feature_alignment = False
if train_with_multi_view_feature_alignment:
main_logger.info("MultiViewFeatureExtractor is ignored during inference.")
def init_from_ckpt(self, checkpoint_path):
main_logger.info(f"Initializing model from checkpoint {checkpoint_path}...")
def grab_ipa_weight(state_dict):
ipa_state_dict = OrderedDict()
for n in list(state_dict.keys()):
if "to_k_ip" in n or "to_v_ip" in n:
ipa_state_dict[n] = state_dict[n]
elif "image_proj_model" in n:
if (
self.use_camera_pose_query_transformer
and "image_proj_model.latents" in n
):
ipa_state_dict[n] = torch.cat(
[state_dict[n] for i in range(16)], dim=1
)
else:
ipa_state_dict[n] = state_dict[n]
return ipa_state_dict
state_dict = torch.load(checkpoint_path, map_location="cpu")
if "module" in state_dict.keys():
# deepspeed
target_state_dict = OrderedDict()
for key in state_dict["module"].keys():
target_state_dict[key[16:]] = state_dict["module"][key]
elif "state_dict" in list(state_dict.keys()):
target_state_dict = state_dict["state_dict"]
else:
raise KeyError("Weight key is not found in the state dict.")
ipa_state_dict = grab_ipa_weight(target_state_dict)
self.load_state_dict(ipa_state_dict, strict=False)
main_logger.info("Checkpoint loaded.")
def init_img_embedder(self, config, freeze=True):
embedder = instantiate_from_config(config)
if freeze:
self.embedder = embedder.eval()
self.embedder.train = disabled_train
for param in self.embedder.parameters():
param.requires_grad = False
def make_cond_schedule(
self,
):
self.cond_ids = torch.full(
size=(self.num_time_steps,),
fill_value=self.num_time_steps - 1,
dtype=torch.long,
)
ids = torch.round(
torch.linspace(0, self.num_time_steps - 1, self.num_time_steps_cond)
).long()
self.cond_ids[: self.num_time_steps_cond] = ids
def init_projector(self, config, trainable):
self.image_proj_model = instantiate_from_config(config)
if not trainable:
self.image_proj_model.eval()
self.image_proj_model.train = disabled_train
for param in self.image_proj_model.parameters():
param.requires_grad = False
@staticmethod
def pad_cond_images(batch_images):
h, w = batch_images.shape[-2:]
border = (w - h) // 2
# F.pad takes (w_left, w_right, h_top, h_bottom) for the last two dims; pad along height only
batch_images = torch.nn.functional.pad(
batch_images, (0, 0, border, border), "constant", 0
)
return batch_images
# Never delete this func: it is used in log_images() and inference stage
def get_image_embeds(self, batch_images, batch=None):
# input shape: b c h w
if self.padding:
batch_images = self.pad_cond_images(batch_images)
img_token = self.embedder(batch_images)
if self.use_camera_pose_query_transformer:
batch_size, num_views, _ = batch["target_poses"].shape
img_emb = self.image_proj_model(
img_token, batch["target_poses"].reshape(batch_size, num_views, 12)
)
else:
img_emb = self.image_proj_model(img_token)
return img_emb
@staticmethod
def get_input(batch, k):
x = batch[k]
"""
# for image batch from image loader
if len(x.shape) == 4:
x = rearrange(x, 'b h w c -> b c h w')
"""
x = x.to(memory_format=torch.contiguous_format) # .float()
return x
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
# only for very first batch, reset the self.scale_factor
if (
self.scale_by_std
and self.current_epoch == 0
and self.global_step == 0
and batch_idx == 0
and not self.restarted_from_ckpt
):
assert (
self.scale_factor == 1.0
), "rather not use custom rescaling and std-rescaling simultaneously"
# set rescale weight to 1./std of encodings
main_logger.info("## USING STD-RESCALING ###")
x = self.get_input(batch, self.first_stage_key)
x = x.to(self.device)
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()
del self.scale_factor
self.register_buffer("scale_factor", 1.0 / z.flatten().std())
main_logger.info(f"setting self.scale_factor to {self.scale_factor}")
main_logger.info("## USING STD-RESCALING ###")
main_logger.info(f"std={z.flatten().std()}")
def register_schedule(
self,
given_betas=None,
beta_schedule="linear",
time_steps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
):
if exists(given_betas):
betas = given_betas
else:
betas = make_beta_schedule(
beta_schedule,
time_steps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s,
)
if self.rescale_betas_zero_snr:
betas = rescale_zero_terminal_snr(betas)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
(time_steps,) = betas.shape
self.num_time_steps = int(time_steps)
self.linear_start = linear_start
self.linear_end = linear_end
assert (
alphas_cumprod.shape[0] == self.num_time_steps
), "alphas have to be defined for each timestep"
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer("betas", to_torch(betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer(
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod",
to_torch(np.sqrt(1.0 / (alphas_cumprod + 1e-5))),
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod",
to_torch(np.sqrt(1.0 / (alphas_cumprod + 1e-5) - 1)),
)
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (
1.0 - alphas_cumprod_prev
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer("posterior_variance", to_torch(posterior_variance))
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer(
"posterior_log_variance_clipped",
to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
)
self.register_buffer(
"posterior_mean_coef1",
to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
)
self.register_buffer(
"posterior_mean_coef2",
to_torch(
(1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
),
)
if self.parameterization == "eps":
lvlb_weights = self.betas**2 / (
2
* self.posterior_variance
* to_torch(alphas)
* (1 - self.alphas_cumprod)
)
elif self.parameterization == "x0":
lvlb_weights = (
0.5
* np.sqrt(torch.Tensor(alphas_cumprod))
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
)
elif self.parameterization == "v":
lvlb_weights = torch.ones_like(
self.betas**2
/ (
2
* self.posterior_variance
* to_torch(alphas)
* (1 - self.alphas_cumprod)
)
)
else:
raise NotImplementedError("mu not supported")
lvlb_weights[0] = lvlb_weights[1]
self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
assert not torch.isnan(self.lvlb_weights).any(), "lvlb_weights contain NaN values"
if self.shorten_cond_schedule:
self.make_cond_schedule()
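# The buffers registered above implement the standard DDPM identities (a summary, not new math):
#   q(x_t | x_0)          = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
#   q(x_{t-1} | x_t, x_0) = N(coef1 * x_0 + coef2 * x_t, posterior_variance)
# with coef1 = beta_t * sqrt(alpha_bar_{t-1}) / (1 - alpha_bar_t)
# and  coef2 = (1 - alpha_bar_{t-1}) * sqrt(alpha_t) / (1 - alpha_bar_t).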
def instantiate_first_stage(self, config):
model = instantiate_from_config(config)
self.first_stage_model = model.eval()
self.first_stage_model.train = disabled_train
for param in self.first_stage_model.parameters():
param.requires_grad = False
def instantiate_cond_stage(self, config):
if not self.cond_stage_trainable:
model = instantiate_from_config(config)
self.cond_stage_model = model.eval()
self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters():
param.requires_grad = False
else:
model = instantiate_from_config(config)
self.cond_stage_model = model
def get_learned_conditioning(self, c):
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, "encode") and callable(
self.cond_stage_model.encode
):
c = self.cond_stage_model.encode(c)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
else:
c = self.cond_stage_model(c)
else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c
def get_first_stage_encoding(self, encoder_posterior, noise=None):
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
z = encoder_posterior.sample(noise=noise)
elif isinstance(encoder_posterior, torch.Tensor):
z = encoder_posterior
else:
raise NotImplementedError(
f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
)
return self.scale_factor * z
@torch.no_grad()
def encode_first_stage(self, x):
assert x.dim() == 5 or x.dim() == 4, (
"Images should be a either 5-dimensional (batched image sequence) "
"or 4-dimensional (batched images)."
)
if (
self.encoder_type == "2d"
and x.dim() == 5
and not self.per_frame_auto_encoding
):
b, t, _, _, _ = x.shape
x = rearrange(x, "b t c h w -> (b t) c h w")
reshape_back = True
else:
b, _, _, _, _ = x.shape
t = 1
reshape_back = False
if not self.per_frame_auto_encoding:
encoder_posterior = self.first_stage_model.encode(x)
results = self.get_first_stage_encoding(encoder_posterior).detach()
else:
results = []
for index in range(x.shape[1]):
frame_batch = self.first_stage_model.encode(x[:, index, :, :, :])
frame_result = self.get_first_stage_encoding(frame_batch).detach()
results.append(frame_result)
results = torch.stack(results, dim=1)
if reshape_back:
results = rearrange(results, "(b t) c h w -> b t c h w", b=b, t=t)
return results
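# Shape sketch for encode_first_stage (assuming a 4-channel VAE with 8x spatial downsampling,
# which is a common but not guaranteed first-stage configuration):
#   images (b, t, 3, H, W)  ->  latents (b, t, 4, H / 8, W / 8)
# With encoder_type == "2d" and per_frame_auto_encoding == False, frames are folded into the
# batch dimension before encoding and unfolded afterwards.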
def decode_core(self, z, **kwargs):
assert z.dim() == 5 or z.dim() == 4, (
"Latents should be a either 5-dimensional (batched latent sequence) "
"or 4-dimensional (batched latents)."
)
if (
self.encoder_type == "2d"
and z.dim() == 5
and not self.per_frame_auto_encoding
):
b, t, _, _, _ = z.shape
z = rearrange(z, "b t c h w -> (b t) c h w")
reshape_back = True
else:
b, _, _, _, _ = z.shape
t = 1
reshape_back = False
if not self.per_frame_auto_encoding:
z = 1.0 / self.scale_factor * z
results = self.first_stage_model.decode(z, **kwargs)
else:
results = []
for index in range(z.shape[1]):
frame_z = 1.0 / self.scale_factor * z[:, index, :, :, :]
frame_result = self.first_stage_model.decode(frame_z, **kwargs)
results.append(frame_result)
results = torch.stack(results, dim=1)
if reshape_back:
results = rearrange(results, "(b t) c h w -> b t c h w", b=b, t=t)
return results
@torch.no_grad()
def decode_first_stage(self, z, **kwargs):
return self.decode_core(z, **kwargs)
def differentiable_decode_first_stage(self, z, **kwargs):
return self.decode_core(z, **kwargs)
def get_batch_input(
self,
batch,
random_drop_training_conditions,
return_reconstructed_target_images=False,
):
combined_images = batch[self.data_key_images]
clean_combined_image_latents = self.encode_first_stage(combined_images)
mask_preserving_target = batch["mask_preserving_target"].reshape(
batch["mask_preserving_target"].size(0),
batch["mask_preserving_target"].size(1),
1,
1,
1,
)
mask_preserving_condition = 1.0 - mask_preserving_target
if self.ray_as_image:
clean_combined_ray_images = batch[self.data_key_rays]
clean_combined_ray_o_latents = self.encode_first_stage(
clean_combined_ray_images[:, :, :3, :, :]
)
clean_combined_ray_d_latents = self.encode_first_stage(
clean_combined_ray_images[:, :, 3:, :, :]
)
clean_combined_rays = torch.concat(
[clean_combined_ray_o_latents, clean_combined_ray_d_latents], dim=2
)
if self.condition_padding_with_anchor:
condition_ray_images = batch["condition_rays"]
condition_ray_o_images = self.encode_first_stage(
condition_ray_images[:, :, :3, :, :]
)
condition_ray_d_images = self.encode_first_stage(
condition_ray_images[:, :, 3:, :, :]
)
condition_rays = torch.concat(
[condition_ray_o_images, condition_ray_d_images], dim=2
)
else:
condition_rays = clean_combined_rays * mask_preserving_target
else:
clean_combined_rays = batch[self.data_key_rays]
if self.condition_padding_with_anchor:
condition_rays = batch["condition_rays"]
else:
condition_rays = clean_combined_rays * mask_preserving_target
if self.condition_padding_with_anchor:
condition_images_latents = self.encode_first_stage(
batch["condition_images"]
)
else:
condition_images_latents = (
clean_combined_image_latents * mask_preserving_condition
)
if random_drop_training_conditions:
random_num = torch.rand(
combined_images.size(0), device=combined_images.device
)
else:
random_num = torch.ones(
combined_images.size(0), device=combined_images.device
)
text_feature_condition_mask = rearrange(
random_num < 2 * self.uncond_prob, "n -> n 1 1"
)
image_feature_condition_mask = 1 - rearrange(
(random_num >= self.uncond_prob).float()
* (random_num < 3 * self.uncond_prob).float(),
"n -> n 1 1 1 1",
)
ray_condition_mask = 1 - rearrange(
(random_num >= 1.5 * self.uncond_prob).float()
* (random_num < 3.5 * self.uncond_prob).float(),
"n -> n 1 1 1 1",
)
mask_preserving_first_target = batch[
"mask_only_preserving_first_target"
].reshape(
batch["mask_only_preserving_first_target"].size(0),
batch["mask_only_preserving_first_target"].size(1),
1,
1,
1,
)
mask_preserving_first_condition = batch[
"mask_only_preserving_first_condition"
].reshape(
batch["mask_only_preserving_first_condition"].size(0),
batch["mask_only_preserving_first_condition"].size(1),
1,
1,
1,
)
mask_preserving_anchors = (
mask_preserving_first_target + mask_preserving_first_condition
)
mask_randomly_preserving_first_target = torch.where(
ray_condition_mask.repeat(1, mask_preserving_first_target.size(1), 1, 1, 1)
== 1.0,
1.0,
mask_preserving_first_target,
)
mask_randomly_preserving_first_condition = torch.where(
image_feature_condition_mask.repeat(
1, mask_preserving_first_condition.size(1), 1, 1, 1
)
== 1.0,
1.0,
mask_preserving_first_condition,
)
if self.use_text_cross_attention_condition:
text_cond_key = self.data_key_text_condition
text_cond = batch[text_cond_key]
if isinstance(text_cond, dict) or isinstance(text_cond, list):
full_text_cond_emb = self.get_learned_conditioning(text_cond)
else:
full_text_cond_emb = self.get_learned_conditioning(
text_cond.to(self.device)
)
null_text_cond_emb = self.get_learned_conditioning([""])
text_cond_emb = torch.where(
text_feature_condition_mask,
null_text_cond_emb,
full_text_cond_emb.detach(),
)
batch_size, num_views, _, _, _ = batch[self.data_key_images].shape
if self.condition_padding_with_anchor:
condition_images = batch["condition_images"]
else:
condition_images = combined_images * mask_preserving_condition
if random_drop_training_conditions:
condition_image_for_embedder = rearrange(
condition_images * image_feature_condition_mask,
"b t c h w -> (b t) c h w",
)
else:
condition_image_for_embedder = rearrange(
condition_images, "b t c h w -> (b t) c h w"
)
img_token = self.embedder(condition_image_for_embedder)
if self.use_camera_pose_query_transformer:
img_emb = self.image_proj_model(
img_token, batch["target_poses"].reshape(batch_size, num_views, 12)
)
else:
img_emb = self.image_proj_model(img_token)
img_emb = rearrange(
img_emb, "(b t) s d -> b (t s) d", b=batch_size, t=num_views
)
if self.use_text_cross_attention_condition:
c_crossattn = [torch.cat([text_cond_emb, img_emb], dim=1)]
else:
c_crossattn = [img_emb]
cond_dict = {
"c_crossattn": c_crossattn,
"target_camera_poses": batch["target_and_condition_camera_poses"]
* batch["mask_preserving_target"].unsqueeze(-1),
}
if self.disable_ray_stream:
clean_gt = torch.cat([clean_combined_image_latents], dim=2)
else:
clean_gt = torch.cat(
[clean_combined_image_latents, clean_combined_rays], dim=2
)
if random_drop_training_conditions:
combined_condition = torch.cat(
[
condition_images_latents * mask_randomly_preserving_first_condition,
condition_rays * mask_randomly_preserving_first_target,
],
dim=2,
)
else:
combined_condition = torch.cat(
[condition_images_latents, condition_rays], dim=2
)
uncond_combined_condition = torch.cat(
[
condition_images_latents * mask_preserving_anchors,
condition_rays * mask_preserving_anchors,
],
dim=2,
)
mask_full_for_input = torch.cat(
[
mask_preserving_condition.repeat(
1, 1, condition_images_latents.size(2), 1, 1
),
mask_preserving_target.repeat(1, 1, condition_rays.size(2), 1, 1),
],
dim=2,
)
cond_dict.update(
{
"mask_preserving_target": mask_preserving_target,
"mask_preserving_condition": mask_preserving_condition,
"combined_condition": combined_condition,
"uncond_combined_condition": uncond_combined_condition,
"clean_combined_rays": clean_combined_rays,
"mask_full_for_input": mask_full_for_input,
"num_cond_images": rearrange(
batch["num_cond_images"].float(), "b -> b 1 1 1 1"
),
"num_target_images": rearrange(
batch["num_target_images"].float(), "b -> b 1 1 1 1"
),
}
)
out = [clean_gt, cond_dict]
if return_reconstructed_target_images:
target_images_reconstructed = self.decode_first_stage(
clean_combined_image_latents
)
out.append(target_images_reconstructed)
return out
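# Returned structure (a summary sketch of the code above, not an API contract):
#   clean_gt:  clean image latents, optionally concatenated with clean ray maps along dim=2
#   cond_dict: cross-attention context ("c_crossattn"), target camera poses, masked condition
#              latents ("combined_condition"), the unconditional condition, masks, and view counts
#   optionally the VAE reconstruction of the target images when requested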
def get_dynamic_scales(self, t, spin_step=400):
base_scale = self.base_scale
scale_t = torch.where(
t < spin_step,
t * (base_scale - 1.0) / spin_step + 1.0,
base_scale * torch.ones_like(t),
)
return scale_t
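# get_dynamic_scales ramps the latent scale linearly from 1.0 at t=0 down to base_scale at
# t=spin_step, then keeps it constant, e.g. with base_scale=0.3 and spin_step=400:
#   t=0 -> 1.0, t=200 -> 0.65, t>=400 -> 0.3  (values are illustrative only).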
def forward(self, x, c, **kwargs):
t = torch.randint(
0, self.num_time_steps, (x.shape[0],), device=self.device
).long()
if self.use_dynamic_rescale:
x = x * extract_into_tensor(self.scale_arr, t, x.shape)
return self.p_losses(x, c, t, **kwargs)
def extract_feature(self, batch, t, **kwargs):
z, cond = self.get_batch_input(
batch,
random_drop_training_conditions=False,
return_reconstructed_target_images=False,
)
if self.use_dynamic_rescale:
z = z * extract_into_tensor(self.scale_arr, t, z.shape)
noise = torch.randn_like(z)
if self.use_noise_offset:
noise = noise + 0.1 * torch.randn(
noise.shape[0], noise.shape[1], 1, 1, 1
).to(self.device)
x_noisy = self.q_sample(x_start=z, t=t, noise=noise)
x_noisy = self.process_x_with_condition(x_noisy, condition_dict=cond)
c_crossattn = torch.cat(cond["c_crossattn"], 1)
target_camera_poses = cond["target_camera_poses"]
x_pred, features = self.model(
x_noisy,
t,
context=c_crossattn,
return_output_block_features=True,
camera_poses=target_camera_poses,
**kwargs,
)
return x_pred, features, z
def apply_model(self, x_noisy, t, cond, features_to_return=None, **kwargs):
if not isinstance(cond, dict):
if not isinstance(cond, list):
cond = [cond]
key = (
"c_concat" if self.model.conditioning_key == "concat" else "c_crossattn"
)
cond = {key: cond}
c_crossattn = torch.cat(cond["c_crossattn"], 1)
x_noisy = self.process_x_with_condition(x_noisy, condition_dict=cond)
target_camera_poses = cond["target_camera_poses"]
if self.use_task_embedding:
x_pred_images = self.model(
x_noisy,
t,
context=c_crossattn,
task_idx=TASK_IDX_IMAGE,
camera_poses=target_camera_poses,
**kwargs,
)
x_pred_rays = self.model(
x_noisy,
t,
context=c_crossattn,
task_idx=TASK_IDX_RAY,
camera_poses=target_camera_poses,
**kwargs,
)
x_pred = torch.concat([x_pred_images, x_pred_rays], dim=2)
elif features_to_return is not None:
x_pred, features = self.model(
x_noisy,
t,
context=c_crossattn,
return_input_block_features="input" in features_to_return,
return_middle_feature="middle" in features_to_return,
return_output_block_features="output" in features_to_return,
camera_poses=target_camera_poses,
**kwargs,
)
return x_pred, features
elif self.train_with_multi_view_feature_alignment:
x_pred, aligned_features = self.model(
x_noisy,
t,
context=c_crossattn,
camera_poses=target_camera_poses,
**kwargs,
)
return x_pred, aligned_features
else:
x_pred = self.model(
x_noisy,
t,
context=c_crossattn,
camera_poses=target_camera_poses,
**kwargs,
)
return x_pred
def process_x_with_condition(self, x_noisy, condition_dict):
combined_condition = condition_dict["combined_condition"]
if self.separate_noise_and_condition:
if self.disable_ray_stream:
x_noisy = torch.concat([x_noisy, combined_condition], dim=2)
else:
x_noisy = torch.concat(
[
x_noisy[:, :, :4, :, :],
combined_condition[:, :, :4, :, :],
x_noisy[:, :, 4:, :, :],
combined_condition[:, :, 4:, :, :],
],
dim=2,
)
else:
assert (
not self.use_ray_decoder_regression
), "`separate_noise_and_condition` must be True when enabling `use_ray_decoder_regression`."
mask_preserving_target = condition_dict["mask_preserving_target"]
mask_preserving_condition = condition_dict["mask_preserving_condition"]
# Concatenate along the channel dim so the masks match the 4 image + 6 ray channel layout
mask_for_combined_condition = torch.cat(
[
mask_preserving_target.repeat(1, 1, 4, 1, 1),
mask_preserving_condition.repeat(1, 1, 6, 1, 1),
],
dim=2,
)
mask_for_x_noisy = torch.cat(
[
mask_preserving_target.repeat(1, 1, 4, 1, 1),
mask_preserving_condition.repeat(1, 1, 6, 1, 1),
],
dim=2,
)
x_noisy = (
x_noisy * mask_for_x_noisy
+ combined_condition * mask_for_combined_condition
)
return x_noisy
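# Channel layout sketch after the interleave above (assuming 4 latent channels per image and
# 6 ray-map channels, as the slicing indices and repeats suggest):
#   [noisy image latents (4) | condition image latents (4) | noisy rays (6) | condition rays (6)]
# When separate_noise_and_condition is False, noise and condition are instead blended through
# the target / condition masks.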
def p_losses(self, x_start, cond, t, noise=None, **kwargs):
noise = default(noise, lambda: torch.randn_like(x_start))
if self.use_noise_offset:
noise = noise + 0.1 * torch.randn(
noise.shape[0], noise.shape[1], 1, 1, 1
).to(self.device)
# Optional correlated-noise trick: share the anchor-frame noise across the remaining frames (bd noise)
if self.bd_noise:
noise_decor = self.bd(noise)
noise_decor = (noise_decor - noise_decor.mean()) / (
noise_decor.std() + 1e-5
)
noise_f = noise_decor[:, :, 0:1, :, :]
noise = (
np.sqrt(self.bd_ratio) * noise_decor[:, :, 1:]
+ np.sqrt(1 - self.bd_ratio) * noise_f
)
noise = torch.cat([noise_f, noise], dim=2)
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
if self.train_with_multi_view_feature_alignment:
model_output, aligned_features = self.apply_model(
x_noisy, t, cond, **kwargs
)
aligned_middle_feature = rearrange(
aligned_features,
"(b t) c h w -> b (t c h w)",
b=cond["pts_anchor_to_all"].size(0),
t=cond["pts_anchor_to_all"].size(1),
)
target_multi_view_feature = rearrange(
torch.concat(
[cond["pts_anchor_to_all"], cond["pts_all_to_anchor"]], dim=2
),
"b t c h w -> b (t c h w)",
).to(aligned_middle_feature.device)
else:
model_output = self.apply_model(x_noisy, t, cond, **kwargs)
loss_dict = {}
prefix = "train" if self.training else "val"
if self.parameterization == "x0":
target = x_start
elif self.parameterization == "eps":
target = noise
elif self.parameterization == "v":
target = self.get_v(x_start, noise, t)
else:
raise NotImplementedError()
if self.apply_condition_mask_in_training_loss:
mask_full_for_output = 1.0 - cond["mask_full_for_input"]
model_output = model_output * mask_full_for_output
target = target * mask_full_for_output
loss_simple = self.get_loss(model_output, target, mean=False)
if self.ray_loss_weight != 1.0:
loss_simple[:, :, 4:, :, :] = (
loss_simple[:, :, 4:, :, :] * self.ray_loss_weight
)
if self.apply_condition_mask_in_training_loss:
# Reweight the two streams: image loss is predicted on target views, ray loss on condition views
num_total_images = cond["num_cond_images"] + cond["num_target_images"]
weight_for_image_loss = num_total_images / cond["num_target_images"]
weight_for_ray_loss = num_total_images / cond["num_cond_images"]
loss_simple[:, :, :4, :, :] = (
loss_simple[:, :, :4, :, :] * weight_for_image_loss
)
# Ray loss: predicted items = # of condition images
loss_simple[:, :, 4:, :, :] = (
loss_simple[:, :, 4:, :, :] * weight_for_ray_loss
)
loss_dict.update({f"{prefix}/loss_images": loss_simple[:, :, 0:4, :, :].mean()})
if not self.disable_ray_stream:
loss_dict.update(
{f"{prefix}/loss_rays": loss_simple[:, :, 4:, :, :].mean()}
)
loss_simple = loss_simple.mean([1, 2, 3, 4])
loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()})
if self.logvar.device != self.device:
self.logvar = self.logvar.to(self.device)
logvar_t = self.logvar[t]
loss = loss_simple / torch.exp(logvar_t) + logvar_t
if self.learn_logvar:
loss_dict.update({f"{prefix}/loss_gamma": loss.mean()})
loss_dict.update({"logvar": self.logvar.data.mean()})
loss = self.l_simple_weight * loss.mean()
if self.train_with_multi_view_feature_alignment:
multi_view_feature_alignment_loss = 0.25 * torch.nn.functional.mse_loss(
aligned_middle_feature, target_multi_view_feature
)
loss += multi_view_feature_alignment_loss
loss_dict.update(
{f"{prefix}/loss_mv_feat_align": multi_view_feature_alignment_loss}
)
loss_vlb = self.get_loss(model_output, target, mean=False).mean(
dim=(1, 2, 3, 4)
)
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
loss_dict.update({f"{prefix}/loss_vlb": loss_vlb})
loss += self.original_elbo_weight * loss_vlb
loss_dict.update({f"{prefix}/loss": loss})
return loss, loss_dict
def _get_denoise_row_from_list(self, samples, desc=""):
denoise_row = []
for zd in tqdm(samples, desc=desc):
denoise_row.append(self.decode_first_stage(zd.to(self.device)))
n_log_time_steps = len(denoise_row)
denoise_row = torch.stack(denoise_row) # n_log_time_steps, b, C, H, W
if denoise_row.dim() == 5:
denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w")
denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
denoise_grid = make_grid(denoise_grid, nrow=n_log_time_steps)
elif denoise_row.dim() == 6:
video_length = denoise_row.shape[3]
denoise_grid = rearrange(denoise_row, "n b c t h w -> b n c t h w")
denoise_grid = rearrange(denoise_grid, "b n c t h w -> (b n) c t h w")
denoise_grid = rearrange(denoise_grid, "n c t h w -> (n t) c h w")
denoise_grid = make_grid(denoise_grid, nrow=video_length)
else:
raise ValueError
return denoise_grid
@torch.no_grad()
def log_images(
self,
batch,
sample=True,
ddim_steps=50,
ddim_eta=1.0,
plot_denoise_rows=False,
unconditional_guidance_scale=1.0,
**kwargs,
):
"""log images for LatentDiffusion"""
use_ddim = ddim_steps is not None
log = dict()
z, cond, x_rec = self.get_batch_input(
batch,
random_drop_training_conditions=False,
return_reconstructed_target_images=True,
)
b, t, c, h, w = x_rec.shape
log["num_cond_images_str"] = batch["num_cond_images_str"]
log["caption"] = batch["caption"]
if "condition_images" in batch:
log["input_condition_images_all"] = batch["condition_images"]
log["input_condition_image_latents_masked"] = cond["combined_condition"][
:, :, 0:3, :, :
]
log["input_condition_rays_o_masked"] = (
cond["combined_condition"][:, :, 4:7, :, :] / 5.0
)
log["input_condition_rays_d_masked"] = (
cond["combined_condition"][:, :, 7:, :, :] / 5.0
)
log["gt_images_after_vae"] = x_rec
if self.train_with_multi_view_feature_alignment:
log["pts_anchor_to_all"] = cond["pts_anchor_to_all"]
log["pts_all_to_anchor"] = cond["pts_all_to_anchor"]
log["pts_anchor_to_all"] = (
log["pts_anchor_to_all"] - torch.min(log["pts_anchor_to_all"])
) / torch.max(log["pts_anchor_to_all"])
log["pts_all_to_anchor"] = (
log["pts_all_to_anchor"] - torch.min(log["pts_all_to_anchor"])
) / torch.max(log["pts_all_to_anchor"])
if self.ray_as_image:
log["gt_rays_o"] = batch["combined_rays"][:, :, 0:3, :, :]
log["gt_rays_d"] = batch["combined_rays"][:, :, 3:, :, :]
else:
log["gt_rays_o"] = batch["combined_rays"][:, :, 0:3, :, :] / 5.0
log["gt_rays_d"] = batch["combined_rays"][:, :, 3:, :, :] / 5.0
if sample:
# get uncond embedding for classifier-free guidance sampling
if unconditional_guidance_scale != 1.0:
uc = self.get_unconditional_dict_for_sampling(batch, cond, x_rec)
else:
uc = None
with self.ema_scope("Plotting"):
out = self.sample_log(
cond=cond,
batch_size=b,
ddim=use_ddim,
ddim_steps=ddim_steps,
eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc,
mask=self.cond_mask,
x0=z,
with_extra_returned_data=False,
**kwargs,
)
samples, z_denoise_row = out
per_instance_decoding = False
if per_instance_decoding:
x_sample_images = []
for idx in range(b):
sample_image = samples[idx : idx + 1, :, 0:4, :, :]
x_sample_image = self.decode_first_stage(sample_image)
x_sample_images.append(x_sample_image)
x_sample_images = torch.cat(x_sample_images, dim=0)
else:
x_sample_images = self.decode_first_stage(samples[:, :, 0:4, :, :])
log["sample_images"] = x_sample_images
if not self.disable_ray_stream:
if self.ray_as_image:
log["sample_rays_o"] = self.decode_first_stage(
samples[:, :, 4:8, :, :]
)
log["sample_rays_d"] = self.decode_first_stage(
samples[:, :, 8:, :, :]
)
else:
log["sample_rays_o"] = samples[:, :, 4:7, :, :] / 5.0
log["sample_rays_d"] = samples[:, :, 7:, :, :] / 5.0
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
return log
def get_unconditional_dict_for_sampling(self, batch, cond, x_rec, is_extra=False):
b, t, c, h, w = x_rec.shape
if self.use_text_cross_attention_condition:
if self.uncond_type == "empty_seq":
# NVComposer's cross attention layers accept multi-view images
prompts = b * [""]
# prompts = b * t * [""] # if is_image_batch=True
uc_emb = self.get_learned_conditioning(prompts)
elif self.uncond_type == "zero_embed":
c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond
uc_emb = torch.zeros_like(c_emb)
else:
uc_emb = None
# process image condition
if not is_extra:
if hasattr(self, "embedder"):
# uc_img = torch.zeros_like(x[:, :, 0, ...]) # b c h w
uc_img = torch.zeros(
# b c h w
size=(b * t, c, h, w),
dtype=x_rec.dtype,
device=x_rec.device,
)
# img: b c h w >> b l c
uc_img = self.get_image_embeds(uc_img, batch)
# Modified: The uc embeddings should be reshaped for valid post-processing
uc_img = rearrange(
uc_img, "(b t) s d -> b (t s) d", b=b, t=uc_img.shape[0] // b
)
if uc_emb is None:
uc_emb = uc_img
else:
uc_emb = torch.cat([uc_emb, uc_img], dim=1)
uc = {key: cond[key] for key in cond.keys()}
uc.update({"c_crossattn": [uc_emb]})
else:
uc = {key: cond[key] for key in cond.keys()}
uc.update({"combined_condition": uc["uncond_combined_condition"]})
return uc
def p_mean_variance(
self,
x,
c,
t,
clip_denoised: bool,
return_x0=False,
score_corrector=None,
corrector_kwargs=None,
**kwargs,
):
t_in = t
model_out = self.apply_model(x, t_in, c, **kwargs)
if score_corrector is not None:
assert self.parameterization == "eps"
model_out = score_corrector.modify_score(
self, model_out, x, t, c, **corrector_kwargs
)
if self.parameterization == "eps":
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
elif self.parameterization == "x0":
x_recon = model_out
else:
raise NotImplementedError()
if clip_denoised:
x_recon.clamp_(-1.0, 1.0)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
x_start=x_recon, x_t=x, t=t
)
if return_x0:
return model_mean, posterior_variance, posterior_log_variance, x_recon
else:
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(
self,
x,
c,
t,
clip_denoised=False,
repeat_noise=False,
return_x0=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
**kwargs,
):
b, *_, device = *x.shape, x.device
outputs = self.p_mean_variance(
x=x,
c=c,
t=t,
clip_denoised=clip_denoised,
return_x0=return_x0,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
**kwargs,
)
if return_x0:
model_mean, _, model_log_variance, x0 = outputs
else:
model_mean, _, model_log_variance = outputs
noise = noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
if return_x0:
return (
model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
x0,
)
else:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
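# Ancestral sampling step assembled above (the standard DDPM identity, not new code):
#   x_{t-1} = mu_theta(x_t, t) + sigma_t * z,  z ~ N(0, I),
# where nonzero_mask zeroes the noise term at t == 0 so the final step returns the mean.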
@torch.no_grad()
def p_sample_loop(
self,
cond,
shape,
return_intermediates=False,
x_T=None,
verbose=True,
callback=None,
time_steps=None,
mask=None,
x0=None,
img_callback=None,
start_T=None,
log_every_t=None,
**kwargs,
):
if not log_every_t:
log_every_t = self.log_every_t
device = self.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
intermediates = [img]
if time_steps is None:
time_steps = self.num_time_steps
if start_T is not None:
time_steps = min(time_steps, start_T)
iterator = (
tqdm(reversed(range(0, time_steps)), desc="Sampling t", total=time_steps)
if verbose
else reversed(range(0, time_steps))
)
if mask is not None:
assert x0 is not None
# spatial size has to match
assert x0.shape[2:3] == mask.shape[2:3]
for i in iterator:
ts = torch.full((b,), i, device=device, dtype=torch.long)
if self.shorten_cond_schedule:
assert self.model.conditioning_key != "hybrid"
tc = self.cond_ids[ts].to(cond.device)
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
img = self.p_sample(
img, cond, ts, clip_denoised=self.clip_denoised, **kwargs
)
if mask is not None:
img_orig = self.q_sample(x0, ts)
img = img_orig * mask + (1.0 - mask) * img
if i % log_every_t == 0 or i == time_steps - 1:
intermediates.append(img)
if callback:
callback(i)
if img_callback:
img_callback(img, i)
if return_intermediates:
return img, intermediates
return img
@torch.no_grad()
def sample(
self,
cond,
batch_size=16,
return_intermediates=False,
x_T=None,
verbose=True,
time_steps=None,
mask=None,
x0=None,
shape=None,
**kwargs,
):
if shape is None:
shape = (batch_size, self.channels, self.temporal_length, *self.image_size)
if cond is not None:
if isinstance(cond, dict):
cond = {
key: (
cond[key][:batch_size]
if not isinstance(cond[key], list)
else list(map(lambda x: x[:batch_size], cond[key]))
)
for key in cond
}
else:
cond = (
[c[:batch_size] for c in cond]
if isinstance(cond, list)
else cond[:batch_size]
)
return self.p_sample_loop(
cond,
shape,
return_intermediates=return_intermediates,
x_T=x_T,
verbose=verbose,
time_steps=time_steps,
mask=mask,
x0=x0,
**kwargs,
)
@torch.no_grad()
def sample_log(
self,
cond,
batch_size,
ddim,
ddim_steps,
with_extra_returned_data=False,
**kwargs,
):
if ddim:
ddim_sampler = DDIMSampler(self)
shape = (self.temporal_length, self.channels, *self.image_size)
out = ddim_sampler.sample(
ddim_steps,
batch_size,
shape,
cond,
verbose=True,
with_extra_returned_data=with_extra_returned_data,
**kwargs,
)
if with_extra_returned_data:
samples, intermediates, extra_returned_data = out
return samples, intermediates, extra_returned_data
else:
samples, intermediates = out
return samples, intermediates
else:
samples, intermediates = self.sample(
cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs
)
return samples, intermediates
class DiffusionWrapper(pl.LightningModule):
def __init__(self, diff_model_config):
super().__init__()
self.diffusion_model = instantiate_from_config(diff_model_config)
def forward(self, x, c, **kwargs):
return self.diffusion_model(x, c, **kwargs)
"""SAMPLING ONLY."""
import numpy as np
import torch
from einops import rearrange
from tqdm import tqdm
from core.common import noise_like
from core.models.utils_diffusion import (
make_ddim_sampling_parameters,
make_ddim_time_steps,
rescale_noise_cfg,
)
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_time_steps = model.num_time_steps
self.schedule = schedule
self.counter = 0
def register_buffer(self, name, attr):
if isinstance(attr, torch.Tensor):
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
):
self.ddim_time_steps = make_ddim_time_steps(
ddim_discr_method=ddim_discretize,
num_ddim_time_steps=ddim_num_steps,
num_ddpm_time_steps=self.ddpm_num_time_steps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
alphas_cumprod.shape[0] == self.ddpm_num_time_steps
), "alphas have to be defined for each timestep"
def to_torch(x):
return x.clone().detach().to(torch.float32).to(self.model.device)
if self.model.use_dynamic_rescale:
self.ddim_scale_arr = self.model.scale_arr[self.ddim_time_steps]
self.ddim_scale_arr_prev = torch.cat(
[self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]]
)
self.register_buffer("betas", to_torch(self.model.betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer(
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_one_minus_alphas_cumprod",
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod",
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_time_steps=self.ddim_time_steps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer("ddim_sigmas", ddim_sigmas)
self.register_buffer("ddim_alphas", ddim_alphas)
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
)
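# With ddim_eta == 0 the sigmas above are all zero and sampling becomes the deterministic DDIM
# update; eta == 1 recovers DDPM-like stochasticity. A minimal sketch (hypothetical values):
#   sampler = DDIMSampler(model)
#   sampler.make_schedule(ddim_num_steps=50, ddim_eta=0.0, verbose=False)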
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
schedule_verbose=False,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
unconditional_guidance_scale_extra=1.0,
unconditional_conditioning_extra=None,
with_extra_returned_data=False,
**kwargs,
):
# check condition bs
if conditioning is not None:
if isinstance(conditioning, dict):
try:
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
except:
cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
if cbs != batch_size:
print(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size:
print(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
self.skip_step = self.ddpm_num_time_steps // S
discr_method = (
"uniform_trailing" if self.model.rescale_betas_zero_snr else "uniform"
)
self.make_schedule(
ddim_num_steps=S,
ddim_discretize=discr_method,
ddim_eta=eta,
verbose=schedule_verbose,
)
# make shape
if len(shape) == 3:
C, H, W = shape
size = (batch_size, C, H, W)
elif len(shape) == 4:
T, C, H, W = shape
size = (batch_size, T, C, H, W)
else:
assert False, f"Invalid shape: {shape}."
out = self.ddim_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
unconditional_guidance_scale_extra=unconditional_guidance_scale_extra,
unconditional_conditioning_extra=unconditional_conditioning_extra,
verbose=verbose,
with_extra_returned_data=with_extra_returned_data,
**kwargs,
)
if with_extra_returned_data:
samples, intermediates, extra_returned_data = out
return samples, intermediates, extra_returned_data
else:
samples, intermediates = out
return samples, intermediates
@torch.no_grad()
def ddim_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
time_steps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
unconditional_guidance_scale_extra=1.0,
unconditional_conditioning_extra=None,
verbose=True,
with_extra_returned_data=False,
**kwargs,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device, dtype=self.model.dtype)
if self.model.bd_noise:
noise_decor = self.model.bd(img)
noise_decor = (noise_decor - noise_decor.mean()) / (
noise_decor.std() + 1e-5
)
noise_f = noise_decor[:, :, 0:1, :, :]
noise = (
np.sqrt(self.model.bd_ratio) * noise_decor[:, :, 1:]
+ np.sqrt(1 - self.model.bd_ratio) * noise_f
)
img = torch.cat([noise_f, noise], dim=2)
else:
img = x_T
if time_steps is None:
time_steps = (
self.ddpm_num_time_steps
if ddim_use_original_steps
else self.ddim_time_steps
)
elif time_steps is not None and not ddim_use_original_steps:
subset_end = (
int(
min(time_steps / self.ddim_time_steps.shape[0], 1)
* self.ddim_time_steps.shape[0]
)
- 1
)
time_steps = self.ddim_time_steps[:subset_end]
intermediates = {"x_inter": [img], "pred_x0": [img]}
time_range = (
reversed(range(0, time_steps))
if ddim_use_original_steps
else np.flip(time_steps)
)
total_steps = time_steps if ddim_use_original_steps else time_steps.shape[0]
if verbose:
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
else:
iterator = time_range
# Sampling Loop
for i, step in enumerate(iterator):
print(f"Sample: i={i}, step={step}.")
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
print("ts=", ts)
# use mask to blend noised original latent (img_orig) & new sampled latent (img)
if mask is not None:
assert x0 is not None
img_orig = x0
# keep original & modify use img
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_ddim(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
unconditional_guidance_scale_extra=unconditional_guidance_scale_extra,
unconditional_conditioning_extra=unconditional_conditioning_extra,
with_extra_returned_data=with_extra_returned_data,
**kwargs,
)
if with_extra_returned_data:
img, pred_x0, extra_returned_data = outs
else:
img, pred_x0 = outs
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
# log_every_t = 1
if index % log_every_t == 0 or index == total_steps - 1:
intermediates["x_inter"].append(img)
intermediates["pred_x0"].append(pred_x0)
# intermediates['extra_returned_data'].append(extra_returned_data)
if with_extra_returned_data:
return img, intermediates, extra_returned_data
return img, intermediates
def batch_time_transpose(
self, batch_time_tensor, num_target_views, num_condition_views
):
# Input: N*N; N = T+C
assert num_target_views + num_condition_views == batch_time_tensor.shape[1]
target_tensor = batch_time_tensor[:, :num_target_views, ...] # T*T
condition_tensor = batch_time_tensor[:, num_target_views:, ...] # N*C
target_tensor = target_tensor.transpose(0, 1) # T*T
return torch.concat([target_tensor, condition_tensor], dim=1)
def ddim_batch_shard_step(
self,
pred_x0_post_process_function,
pred_x0_post_process_function_kwargs,
cond,
corrector_kwargs,
ddim_use_original_steps,
device,
img,
index,
kwargs,
noise_dropout,
quantize_denoised,
score_corrector,
step,
temperature,
with_extra_returned_data,
):
img_list = []
pred_x0_list = []
shard_step = 5
shard_start = 0
while shard_start < img.shape[0]:
shard_end = min(shard_start + shard_step, img.shape[0])
print(
f"Sampling Batch Shard: From #{shard_start} to #{shard_end}. Total: {img.shape[0]}."
)
sub_img = img[shard_start:shard_end]
sub_cond = {
"combined_condition": cond["combined_condition"][shard_start:shard_end],
"c_crossattn": [
cond["c_crossattn"][0][0:1].expand(shard_end - shard_start, -1, -1)
],
}
ts = torch.full((sub_img.shape[0],), step, device=device, dtype=torch.long)
_img, _pred_x0 = self.p_sample_ddim(
sub_img,
sub_cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
unconditional_guidance_scale_extra=1.0,
unconditional_conditioning_extra=None,
pred_x0_post_process_function=pred_x0_post_process_function,
pred_x0_post_process_function_kwargs=pred_x0_post_process_function_kwargs,
with_extra_returned_data=with_extra_returned_data,
**kwargs,
)
img_list.append(_img)
pred_x0_list.append(_pred_x0)
shard_start += shard_step
img = torch.concat(img_list, dim=0)
pred_x0 = torch.concat(pred_x0_list, dim=0)
return img, pred_x0
@torch.no_grad()
def p_sample_ddim(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
unconditional_guidance_scale_extra=1.0,
unconditional_conditioning_extra=None,
with_extra_returned_data=False,
**kwargs,
):
b, *_, device = *x.shape, x.device
if x.dim() == 5:
is_video = True
else:
is_video = False
extra_returned_data = None
if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
e_t_cfg = self.model.apply_model(x, t, c, **kwargs) # unet denoiser
if isinstance(e_t_cfg, tuple):
# unpack the extra data before overwriting the tuple with its first element
extra_returned_data = e_t_cfg[1:]
e_t_cfg = e_t_cfg[0]
else:
# with unconditional condition
if isinstance(c, torch.Tensor) or isinstance(c, dict):
e_t = self.model.apply_model(x, t, c, **kwargs)
e_t_uncond = self.model.apply_model(
x, t, unconditional_conditioning, **kwargs
)
if (
unconditional_guidance_scale_extra != 1.0
and unconditional_conditioning_extra is not None
):
print(f"Using extra CFG: {unconditional_guidance_scale_extra}...")
e_t_uncond_extra = self.model.apply_model(
x, t, unconditional_conditioning_extra, **kwargs
)
else:
e_t_uncond_extra = None
else:
raise NotImplementedError
if isinstance(e_t, tuple):
# unpack the extra data before overwriting the tuple with its first element
extra_returned_data = e_t[1:]
e_t = e_t[0]
if isinstance(e_t_uncond, tuple):
e_t_uncond = e_t_uncond[0]
if isinstance(e_t_uncond_extra, tuple):
e_t_uncond_extra = e_t_uncond_extra[0]
# text cfg
if (
unconditional_guidance_scale_extra != 1.0
and unconditional_conditioning_extra is not None
):
e_t_cfg = (
e_t_uncond
+ unconditional_guidance_scale * (e_t - e_t_uncond)
+ unconditional_guidance_scale_extra * (e_t - e_t_uncond_extra)
)
else:
e_t_cfg = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if self.model.rescale_betas_zero_snr:
e_t_cfg = rescale_noise_cfg(e_t_cfg, e_t, guidance_rescale=0.7)
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, e_t_cfg)
else:
e_t = e_t_cfg
if score_corrector is not None:
assert self.model.parameterization == "eps", "not implemented"
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
# select parameters corresponding to the currently considered timestep
if is_video:
size = (b, 1, 1, 1, 1)
else:
size = (b, 1, 1, 1)
a_t = torch.full(size, alphas[index], device=device)
a_prev = torch.full(size, alphas_prev[index], device=device)
sigma_t = torch.full(size, sigmas[index], device=device)
sqrt_one_minus_at = torch.full(
size, sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, e_t_cfg)
if self.model.use_dynamic_rescale:
scale_t = torch.full(size, self.ddim_scale_arr[index], device=device)
prev_scale_t = torch.full(
size, self.ddim_scale_arr_prev[index], device=device
)
rescale = prev_scale_t / scale_t
pred_x0 *= rescale
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = noise_like(x.shape, device, repeat_noise)
if self.model.bd_noise:
noise_decor = self.model.bd(noise)
noise_decor = (noise_decor - noise_decor.mean()) / (
noise_decor.std() + 1e-5
)
noise_f = noise_decor[:, :, 0:1, :, :]
noise = (
np.sqrt(self.model.bd_ratio) * noise_decor[:, :, 1:]
+ np.sqrt(1 - self.model.bd_ratio) * noise_f
)
noise = torch.cat([noise_f, noise], dim=2)
noise = sigma_t * noise * temperature
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
if with_extra_returned_data:
return x_prev, pred_x0, extra_returned_data
return x_prev, pred_x0
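# DDIM update assembled above (standard formulation):
#   x_{t-1} = sqrt(alpha_prev) * pred_x0 + sqrt(1 - alpha_prev - sigma_t^2) * e_t + sigma_t * z
# where pred_x0 is the current estimate of x_0 and z is fresh noise scaled by temperature.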
from .sampler import DPMSolverSampler
import torch
import torch.nn.functional as F
import math
from tqdm import tqdm
class NoiseScheduleVP:
def __init__(
self,
schedule='discrete',
betas=None,
alphas_cumprod=None,
continuous_beta_0=0.1,
continuous_beta_1=20.,
):
"""Create a wrapper class for the forward SDE (VP type).
***
Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
***
The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
log_alpha_t = self.marginal_log_mean_coeff(t)
sigma_t = self.marginal_std(t)
lambda_t = self.marginal_lambda(t)
Moreover, as lambda(t) is an invertible function, we also support its inverse function:
t = self.inverse_lambda(lambda_t)
===============================================================
We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
1. For discrete-time DPMs:
For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
t_i = (i + 1) / N
e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
Args:
betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
**Important**: Please pay special attention for the args for `alphas_cumprod`:
The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
alpha_{t_n} = \sqrt{\hat{alpha_n}},
and
log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
2. For continuous-time DPMs:
We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
schedule are the default settings in DDPM and improved-DDPM:
Args:
beta_min: A `float` number. The smallest beta for the linear schedule.
beta_max: A `float` number. The largest beta for the linear schedule.
cosine_s: A `float` number. The hyperparameter in the cosine schedule.
cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
T: A `float` number. The ending time of the forward process.
===============================================================
Args:
schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
'linear' or 'cosine' for continuous-time DPMs.
Returns:
A wrapper object of the forward SDE (VP type).
===============================================================
Example:
# For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
>>> ns = NoiseScheduleVP('discrete', betas=betas)
# For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
>>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
# For continuous-time DPMs (VPSDE), linear schedule:
>>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
"""
if schedule not in ['discrete', 'linear', 'cosine']:
raise ValueError(
"Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
schedule))
self.schedule = schedule
if schedule == 'discrete':
if betas is not None:
log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
else:
assert alphas_cumprod is not None
log_alphas = 0.5 * torch.log(alphas_cumprod)
self.total_N = len(log_alphas)
self.T = 1.
self.t_array = torch.linspace(
0., 1., self.total_N + 1)[1:].reshape((1, -1))
self.log_alpha_array = log_alphas.reshape((1, -1,))
else:
self.total_N = 1000
self.beta_0 = continuous_beta_0
self.beta_1 = continuous_beta_1
self.cosine_s = 0.008
self.cosine_beta_max = 999.
self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
1. + self.cosine_s) / math.pi - self.cosine_s
self.cosine_log_alpha_0 = math.log(
math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
self.schedule = schedule
if schedule == 'cosine':
# For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
# Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
self.T = 0.9946
else:
self.T = 1.
def marginal_log_mean_coeff(self, t):
"""
Compute log(alpha_t) of a given continuous-time label t in [0, T].
"""
if self.schedule == 'discrete':
return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
self.log_alpha_array.to(t.device)).reshape((-1))
elif self.schedule == 'linear':
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
elif self.schedule == 'cosine':
def log_alpha_fn(s): return torch.log(
torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
return log_alpha_t
def marginal_alpha(self, t):
"""
Compute alpha_t of a given continuous-time label t in [0, T].
"""
return torch.exp(self.marginal_log_mean_coeff(t))
def marginal_std(self, t):
"""
Compute sigma_t of a given continuous-time label t in [0, T].
"""
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
def marginal_lambda(self, t):
"""
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
"""
log_mean_coeff = self.marginal_log_mean_coeff(t)
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
return log_mean_coeff - log_std
def inverse_lambda(self, lamb):
"""
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
"""
if self.schedule == 'linear':
tmp = 2. * (self.beta_1 - self.beta_0) * \
torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
Delta = self.beta_0 ** 2 + tmp
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
elif self.schedule == 'discrete':
log_alpha = -0.5 * \
torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
torch.flip(self.t_array.to(lamb.device), [1]))
return t.reshape((-1,))
else:
log_alpha = -0.5 * \
torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
def t_fn(log_alpha_t): return torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
1. + self.cosine_s) / math.pi - self.cosine_s
t = t_fn(log_alpha)
return t
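# Illustrative usage sketch (an assumption, not part of the original file): construct a
# discrete-time schedule from a toy beta array and sanity-check the identities used throughout
# this module. The helper name `_example_noise_schedule_usage` and the toy betas are for
# demonstration only.
def _example_noise_schedule_usage():
    betas = torch.linspace(1e-4, 2e-2, 1000)
    ns = NoiseScheduleVP('discrete', betas=betas)
    t = torch.tensor([0.5])
    alpha_t, sigma_t, lambda_t = ns.marginal_alpha(t), ns.marginal_std(t), ns.marginal_lambda(t)
    # For VP-type schedules, sigma_t^2 = 1 - alpha_t^2 by construction.
    assert torch.allclose(alpha_t ** 2 + sigma_t ** 2, torch.ones_like(alpha_t), atol=1e-4)
    # inverse_lambda recovers the time label from the half-logSNR (up to interpolation error).
    assert torch.allclose(ns.inverse_lambda(lambda_t), t, atol=1e-2)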
def model_wrapper(
model,
noise_schedule,
model_type="noise",
model_kwargs={},
guidance_type="uncond",
condition=None,
unconditional_condition=None,
guidance_scale=1.,
classifier_fn=None,
classifier_kwargs={},
):
"""Create a wrapper function for the noise prediction model.
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
We support four types of the diffusion model by setting `model_type`:
1. "noise": noise prediction model. (Trained by predicting noise).
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
3. "v": velocity prediction model. (Trained by predicting the velocity).
            The "v" prediction is derived in detail in Appendix D of [1], and is used in Imagen-Video [2].
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
arXiv preprint arXiv:2202.00512 (2022).
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
arXiv preprint arXiv:2210.02303 (2022).
4. "score": marginal score function. (Trained by denoising score matching).
            Note that the score function and the noise prediction model follow a simple relationship:
```
noise(x_t, t) = -sigma_t * score(x_t, t)
```
We support three types of guided sampling by DPMs by setting `guidance_type`:
1. "uncond": unconditional sampling by DPMs.
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
The input `classifier_fn` has the following format:
``
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
``
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
The input `model` has the following format:
``
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
``
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
arXiv preprint arXiv:2207.12598 (2022).
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
or continuous-time labels (i.e. epsilon to T).
We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
``
def model_fn(x, t_continuous) -> noise:
t_input = get_model_input_time(t_continuous)
return noise_pred(model, x, t_input, **model_kwargs)
``
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
===============================================================
Args:
model: A diffusion model with the corresponding format described above.
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
model_type: A `str`. The parameterization type of the diffusion model.
"noise" or "x_start" or "v" or "score".
model_kwargs: A `dict`. A dict for the other inputs of the model function.
guidance_type: A `str`. The type of the guidance for sampling.
"uncond" or "classifier" or "classifier-free".
condition: A pytorch tensor. The condition for the guided sampling.
Only used for "classifier" or "classifier-free" guidance type.
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
Only used for "classifier-free" guidance type.
guidance_scale: A `float`. The scale for the guided sampling.
classifier_fn: A classifier function. Only used for the classifier guidance.
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
Returns:
A noise prediction model that accepts the noised data and the continuous time as the inputs.
"""
def get_model_input_time(t_continuous):
"""
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
For continuous-time DPMs, we just use `t_continuous`.
"""
if noise_schedule.schedule == 'discrete':
return (t_continuous - 1. / noise_schedule.total_N) * 1000.
else:
return t_continuous
def noise_pred_fn(x, t_continuous, cond=None):
if t_continuous.reshape((-1,)).shape[0] == 1:
t_continuous = t_continuous.expand((x.shape[0]))
t_input = get_model_input_time(t_continuous)
if cond is None:
output = model(x, t_input, **model_kwargs)
else:
output = model(x, t_input, cond, **model_kwargs)
if isinstance(output, tuple):
output = output[0]
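        # Convert the raw model output to an epsilon (noise) prediction, matching the branches
        # below: noise = (x - alpha_t * x0) / sigma_t for "x_start" models,
        # noise = alpha_t * v + sigma_t * x for "v" models, and noise = -sigma_t * score
        # for "score" models.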
if model_type == "noise":
return output
elif model_type == "x_start":
alpha_t, sigma_t = noise_schedule.marginal_alpha(
t_continuous), noise_schedule.marginal_std(t_continuous)
dims = x.dim()
return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
elif model_type == "v":
alpha_t, sigma_t = noise_schedule.marginal_alpha(
t_continuous), noise_schedule.marginal_std(t_continuous)
dims = x.dim()
return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
elif model_type == "score":
sigma_t = noise_schedule.marginal_std(t_continuous)
dims = x.dim()
return -expand_dims(sigma_t, dims) * output
def cond_grad_fn(x, t_input):
"""
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
"""
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
log_prob = classifier_fn(
x_in, t_input, condition, **classifier_kwargs)
return torch.autograd.grad(log_prob.sum(), x_in)[0]
def model_fn(x, t_continuous):
"""
        The noise prediction model function that is used for DPM-Solver.
"""
if t_continuous.reshape((-1,)).shape[0] == 1:
t_continuous = t_continuous.expand((x.shape[0]))
if guidance_type == "uncond":
return noise_pred_fn(x, t_continuous)
elif guidance_type == "classifier":
assert classifier_fn is not None
t_input = get_model_input_time(t_continuous)
cond_grad = cond_grad_fn(x, t_input)
sigma_t = noise_schedule.marginal_std(t_continuous)
noise = noise_pred_fn(x, t_continuous)
return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
elif guidance_type == "classifier-free":
if guidance_scale == 1. or unconditional_condition is None:
return noise_pred_fn(x, t_continuous, cond=condition)
else:
# x_in = torch.cat([x] * 2)
# t_in = torch.cat([t_continuous] * 2)
x_in = x
t_in = t_continuous
# c_in = torch.cat([unconditional_condition, condition])
noise = noise_pred_fn(x_in, t_in, cond=condition)
noise_uncond = noise_pred_fn(
x_in, t_in, cond=unconditional_condition)
return noise_uncond + guidance_scale * (noise - noise_uncond)
    assert model_type in ["noise", "x_start", "v", "score"]
assert guidance_type in ["uncond", "classifier", "classifier-free"]
return model_fn
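# Illustrative sketch (an assumption, not part of the original file): wrapping a toy
# noise-prediction model with classifier-free guidance. `_toy_model` and the conditioning
# tensors below are hypothetical stand-ins for a real diffusion UNet and its embeddings.
def _example_model_wrapper_usage():
    ns = NoiseScheduleVP('discrete', betas=torch.linspace(1e-4, 2e-2, 1000))
    def _toy_model(x, t_input, cond):
        # A real model would be a UNet conditioned on `cond`; here we only broadcast it.
        return torch.zeros_like(x) + cond.mean()
    cond, uncond = torch.ones(4, 8), torch.zeros(4, 8)
    model_fn = model_wrapper(
        _toy_model, ns,
        model_type="noise",
        guidance_type="classifier-free",
        condition=cond,
        unconditional_condition=uncond,
        guidance_scale=3.0,
    )
    x = torch.randn(4, 3, 8, 8)
    t = torch.full((4,), 0.5)
    return model_fn(x, t)  # guided noise prediction with shape (4, 3, 8, 8)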
class DPM_Solver:
def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
"""Construct a DPM-Solver.
We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
        In this case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
Args:
model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
``
def model_fn(x, t_continuous):
return noise
``
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
[1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
"""
self.model = model_fn
self.noise_schedule = noise_schedule
self.predict_x0 = predict_x0
self.thresholding = thresholding
self.max_val = max_val
def noise_prediction_fn(self, x, t):
"""
        Return the model's noise prediction at (x, t).
"""
return self.model(x, t)
def data_prediction_fn(self, x, t):
"""
        Return the model's data prediction at (x, t) (with thresholding).
"""
noise = self.noise_prediction_fn(x, t)
dims = x.dim()
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(
t), self.noise_schedule.marginal_std(t)
x0 = (x - expand_dims(sigma_t, dims) * noise) / \
expand_dims(alpha_t, dims)
if self.thresholding:
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
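            # Dynamic thresholding: take the p-quantile s of |x0| per sample (floored at
            # `max_val`), clamp x0 to [-s, s] and divide by s so the result lies in [-1, 1].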
s = torch.quantile(torch.abs(x0).reshape(
(x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.max_val *
torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s
return x0
def model_fn(self, x, t):
"""
Convert the model to the noise prediction model or the data prediction model.
"""
if self.predict_x0:
return self.data_prediction_fn(x, t)
else:
return self.noise_prediction_fn(x, t)
def get_time_steps(self, skip_type, t_T, t_0, N, device):
"""Compute the intermediate time steps for sampling.
Args:
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
- 'logSNR': uniform logSNR for the time steps.
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
            N: An `int`. The total number of the spacing of the time steps.
device: A torch device.
Returns:
A pytorch tensor of the time steps, with the shape (N + 1,).
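        Worked example (illustrative): with t_T = 1.0, t_0 = 1e-3 and N = 3, 'time_uniform'
        returns the N + 1 = 4 boundaries [1.0, 0.667, 0.334, 0.001], while 'logSNR' spaces
        the boundaries uniformly in the half-logSNR lambda_t.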
"""
if skip_type == 'logSNR':
lambda_T = self.noise_schedule.marginal_lambda(
torch.tensor(t_T).to(device))
lambda_0 = self.noise_schedule.marginal_lambda(
torch.tensor(t_0).to(device))
logSNR_steps = torch.linspace(
lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
return self.noise_schedule.inverse_lambda(logSNR_steps)
elif skip_type == 'time_uniform':
return torch.linspace(t_T, t_0, N + 1).to(device)
elif skip_type == 'time_quadratic':
t_order = 2
t = torch.linspace(t_T ** (1. / t_order), t_0 **
(1. / t_order), N + 1).pow(t_order).to(device)
return t
else:
raise ValueError(
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
def get_orders_and_time_steps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
"""
Get the order of each step for sampling by the singlestep DPM-Solver.
        We combine both DPM-Solver-1, 2 and 3 to use all the function evaluations, which is named "DPM-Solver-fast".
Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
- If order == 1:
We take `steps` of DPM-Solver-1 (i.e. DDIM).
- If order == 2:
- Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
- If steps % 2 == 0, we use K steps of DPM-Solver-2.
- If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
- If order == 3:
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
- If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
- If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
- If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
============================================
Args:
            order: An `int`. The max order for the solver (2 or 3).
            steps: An `int`. The total number of function evaluations (NFE).
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
- 'logSNR': uniform logSNR for the time steps.
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
device: A torch device.
Returns:
orders: A list of the solver order of each step.
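        Worked example (illustrative): with steps = 20 and order = 3, K = 20 // 3 + 1 = 7 and,
        since 20 % 3 == 2, orders = [3] * 6 + [2], which uses exactly 6 * 3 + 2 = 20 function
        evaluations.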
"""
if order == 3:
K = steps // 3 + 1
if steps % 3 == 0:
orders = [3, ] * (K - 2) + [2, 1]
elif steps % 3 == 1:
orders = [3, ] * (K - 1) + [1]
else:
orders = [3, ] * (K - 1) + [2]
elif order == 2:
if steps % 2 == 0:
K = steps // 2
orders = [2, ] * K
else:
K = steps // 2 + 1
orders = [2, ] * (K - 1) + [1]
elif order == 1:
K = 1
orders = [1, ] * steps
else:
raise ValueError("'order' must be '1' or '2' or '3'.")
if skip_type == 'logSNR':
# To reproduce the results in DPM-Solver paper
time_steps_outer = self.get_time_steps(
skip_type, t_T, t_0, K, device)
else:
            time_steps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
                torch.cumsum(torch.tensor([0, ] + orders), dim=0).to(device)]
return time_steps_outer, orders
def denoise_to_zero_fn(self, x, s):
"""
        Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
"""
return self.data_prediction_fn(x, s)
def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
"""
DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s`.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
ns = self.noise_schedule
dims = x.dim()
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(
s), ns.marginal_log_mean_coeff(t)
sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
if self.predict_x0:
phi_1 = torch.expm1(-h)
if model_s is None:
model_s = self.model_fn(x, s)
x_t = (
expand_dims(sigma_t / sigma_s, dims) * x
- expand_dims(alpha_t * phi_1, dims) * model_s
)
if return_intermediate:
return x_t, {'model_s': model_s}
else:
return x_t
else:
phi_1 = torch.expm1(h)
if model_s is None:
model_s = self.model_fn(x, s)
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
- expand_dims(sigma_t * phi_1, dims) * model_s
)
if return_intermediate:
return x_t, {'model_s': model_s}
else:
return x_t
def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
solver_type='dpm_solver'):
"""
Singlestep solver DPM-Solver-2 from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
r1: A `float`. The hyperparameter of the second-order solver.
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ['dpm_solver', 'taylor']:
raise ValueError(
"'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
if r1 is None:
r1 = 0.5
ns = self.noise_schedule
dims = x.dim()
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
lambda_s1 = lambda_s + r1 * h
s1 = ns.inverse_lambda(lambda_s1)
log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
s1), ns.marginal_log_mean_coeff(t)
sigma_s, sigma_s1, sigma_t = ns.marginal_std(
s), ns.marginal_std(s1), ns.marginal_std(t)
alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
if self.predict_x0:
phi_11 = torch.expm1(-r1 * h)
phi_1 = torch.expm1(-h)
if model_s is None:
model_s = self.model_fn(x, s)
x_s1 = (
expand_dims(sigma_s1 / sigma_s, dims) * x
- expand_dims(alpha_s1 * phi_11, dims) * model_s
)
model_s1 = self.model_fn(x_s1, s1)
if solver_type == 'dpm_solver':
x_t = (
expand_dims(sigma_t / sigma_s, dims) * x
- expand_dims(alpha_t * phi_1, dims) * model_s
- (0.5 / r1) * expand_dims(alpha_t *
phi_1, dims) * (model_s1 - model_s)
)
elif solver_type == 'taylor':
x_t = (
expand_dims(sigma_t / sigma_s, dims) * x
- expand_dims(alpha_t * phi_1, dims) * model_s
+ (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
model_s1 - model_s)
)
else:
phi_11 = torch.expm1(r1 * h)
phi_1 = torch.expm1(h)
if model_s is None:
model_s = self.model_fn(x, s)
x_s1 = (
expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
- expand_dims(sigma_s1 * phi_11, dims) * model_s
)
model_s1 = self.model_fn(x_s1, s1)
if solver_type == 'dpm_solver':
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
- expand_dims(sigma_t * phi_1, dims) * model_s
- (0.5 / r1) * expand_dims(sigma_t *
phi_1, dims) * (model_s1 - model_s)
)
elif solver_type == 'taylor':
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
- expand_dims(sigma_t * phi_1, dims) * model_s
- (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) -
1.) / h - 1.), dims) * (model_s1 - model_s)
)
if return_intermediate:
return x_t, {'model_s': model_s, 'model_s1': model_s1}
else:
return x_t
def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
return_intermediate=False, solver_type='dpm_solver'):
"""
Singlestep solver DPM-Solver-3 from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
r1: A `float`. The hyperparameter of the third-order solver.
r2: A `float`. The hyperparameter of the third-order solver.
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ['dpm_solver', 'taylor']:
raise ValueError(
"'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
if r1 is None:
r1 = 1. / 3.
if r2 is None:
r2 = 2. / 3.
ns = self.noise_schedule
dims = x.dim()
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
lambda_s1 = lambda_s + r1 * h
lambda_s2 = lambda_s + r2 * h
s1 = ns.inverse_lambda(lambda_s1)
s2 = ns.inverse_lambda(lambda_s2)
log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
s2), ns.marginal_std(t)
alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(
log_alpha_s2), torch.exp(log_alpha_t)
if self.predict_x0:
phi_11 = torch.expm1(-r1 * h)
phi_12 = torch.expm1(-r2 * h)
phi_1 = torch.expm1(-h)
phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
phi_2 = phi_1 / h + 1.
phi_3 = phi_2 / h - 0.5
if model_s is None:
model_s = self.model_fn(x, s)
if model_s1 is None:
x_s1 = (
expand_dims(sigma_s1 / sigma_s, dims) * x
- expand_dims(alpha_s1 * phi_11, dims) * model_s
)
model_s1 = self.model_fn(x_s1, s1)
x_s2 = (
expand_dims(sigma_s2 / sigma_s, dims) * x
- expand_dims(alpha_s2 * phi_12, dims) * model_s
+ r2 / r1 * expand_dims(alpha_s2 * phi_22,
dims) * (model_s1 - model_s)
)
model_s2 = self.model_fn(x_s2, s2)
if solver_type == 'dpm_solver':
x_t = (
expand_dims(sigma_t / sigma_s, dims) * x
- expand_dims(alpha_t * phi_1, dims) * model_s
+ (1. / r2) * expand_dims(alpha_t *
phi_2, dims) * (model_s2 - model_s)
)
elif solver_type == 'taylor':
D1_0 = (1. / r1) * (model_s1 - model_s)
D1_1 = (1. / r2) * (model_s2 - model_s)
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
x_t = (
expand_dims(sigma_t / sigma_s, dims) * x
- expand_dims(alpha_t * phi_1, dims) * model_s
+ expand_dims(alpha_t * phi_2, dims) * D1
- expand_dims(alpha_t * phi_3, dims) * D2
)
else:
phi_11 = torch.expm1(r1 * h)
phi_12 = torch.expm1(r2 * h)
phi_1 = torch.expm1(h)
phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
phi_2 = phi_1 / h - 1.
phi_3 = phi_2 / h - 0.5
if model_s is None:
model_s = self.model_fn(x, s)
if model_s1 is None:
x_s1 = (
expand_dims(
torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
- expand_dims(sigma_s1 * phi_11, dims) * model_s
)
model_s1 = self.model_fn(x_s1, s1)
x_s2 = (
expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
- expand_dims(sigma_s2 * phi_12, dims) * model_s
- r2 / r1 * expand_dims(sigma_s2 *
phi_22, dims) * (model_s1 - model_s)
)
model_s2 = self.model_fn(x_s2, s2)
if solver_type == 'dpm_solver':
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
- expand_dims(sigma_t * phi_1, dims) * model_s
- (1. / r2) * expand_dims(sigma_t *
phi_2, dims) * (model_s2 - model_s)
)
elif solver_type == 'taylor':
D1_0 = (1. / r1) * (model_s1 - model_s)
D1_1 = (1. / r2) * (model_s2 - model_s)
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
- expand_dims(sigma_t * phi_1, dims) * model_s
- expand_dims(sigma_t * phi_2, dims) * D1
- expand_dims(sigma_t * phi_3, dims) * D2
)
if return_intermediate:
return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
else:
return x_t
def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
"""
Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
Args:
            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
model_prev_list: A list of pytorch tensor. The previous computed model values.
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ['dpm_solver', 'taylor']:
raise ValueError(
"'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
ns = self.noise_schedule
dims = x.dim()
model_prev_1, model_prev_0 = model_prev_list
t_prev_1, t_prev_0 = t_prev_list
lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
t_prev_0), ns.marginal_lambda(t)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
t_prev_0), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0 = h_0 / h
D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
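        # D1_0 is the finite-difference estimate of the derivative of the model output
        # w.r.t. lambda, rescaled by the current step size h (note r0 = h_0 / h).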
if self.predict_x0:
if solver_type == 'dpm_solver':
x_t = (
expand_dims(sigma_t / sigma_prev_0, dims) * x
- expand_dims(alpha_t * (torch.exp(-h) - 1.),
dims) * model_prev_0
- 0.5 * expand_dims(alpha_t *
(torch.exp(-h) - 1.), dims) * D1_0
)
elif solver_type == 'taylor':
x_t = (
expand_dims(sigma_t / sigma_prev_0, dims) * x
- expand_dims(alpha_t * (torch.exp(-h) - 1.),
dims) * model_prev_0
+ expand_dims(alpha_t *
((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
)
else:
if solver_type == 'dpm_solver':
x_t = (
expand_dims(
torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- expand_dims(sigma_t * (torch.exp(h) - 1.),
dims) * model_prev_0
- 0.5 * expand_dims(sigma_t *
(torch.exp(h) - 1.), dims) * D1_0
)
elif solver_type == 'taylor':
x_t = (
expand_dims(
torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- expand_dims(sigma_t * (torch.exp(h) - 1.),
dims) * model_prev_0
- expand_dims(sigma_t * ((torch.exp(h) -
1.) / h - 1.), dims) * D1_0
)
return x_t
def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
"""
Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
Args:
            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
model_prev_list: A list of pytorch tensor. The previous computed model values.
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
ns = self.noise_schedule
dims = x.dim()
model_prev_2, model_prev_1, model_prev_0 = model_prev_list
t_prev_2, t_prev_1, t_prev_0 = t_prev_list
lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
t_prev_0), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
h_1 = lambda_prev_1 - lambda_prev_2
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0, r1 = h_0 / h, h_1 / h
D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
if self.predict_x0:
x_t = (
expand_dims(sigma_t / sigma_prev_0, dims) * x
- expand_dims(alpha_t * (torch.exp(-h) - 1.),
dims) * model_prev_0
+ expand_dims(alpha_t *
((torch.exp(-h) - 1.) / h + 1.), dims) * D1
- expand_dims(alpha_t * ((torch.exp(-h) -
1. + h) / h ** 2 - 0.5), dims) * D2
)
else:
x_t = (
expand_dims(
torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- expand_dims(sigma_t * (torch.exp(h) - 1.),
dims) * model_prev_0
- expand_dims(sigma_t *
((torch.exp(h) - 1.) / h - 1.), dims) * D1
- expand_dims(sigma_t * ((torch.exp(h) -
1. - h) / h ** 2 - 0.5), dims) * D2
)
return x_t
def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
r2=None):
"""
Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
r1: A `float`. The hyperparameter of the second-order or third-order solver.
r2: A `float`. The hyperparameter of the third-order solver.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if order == 1:
return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
elif order == 2:
return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
solver_type=solver_type, r1=r1)
elif order == 3:
return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
solver_type=solver_type, r1=r1, r2=r2)
else:
raise ValueError(
"Solver order must be 1 or 2 or 3, got {}".format(order))
def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
"""
Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
Args:
            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
model_prev_list: A list of pytorch tensor. The previous computed model values.
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if order == 1:
return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
elif order == 2:
return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
elif order == 3:
return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
else:
raise ValueError(
"Solver order must be 1 or 2 or 3, got {}".format(order))
def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
solver_type='dpm_solver'):
"""
The adaptive step size solver based on singlestep DPM-Solver.
Args:
x: A pytorch tensor. The initial value at time `t_T`.
            order: An `int`. The (higher) order of the solver. We only support order == 2 or 3.
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
h_init: A `float`. The initial step size (for logSNR).
            atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
            theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
current time and `t_0` is less than `t_err`. The default setting is 1e-5.
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
Returns:
x_0: A pytorch tensor. The approximated solution at time `t_0`.
[1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
"""
ns = self.noise_schedule
s = t_T * torch.ones((x.shape[0],)).to(x)
lambda_s = ns.marginal_lambda(s)
lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
h = h_init * torch.ones_like(s).to(x)
x_prev = x
nfe = 0
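        # Step-size control in the main loop below (illustrative summary): for a candidate step
        # of size h in lambda, compare the lower-order and higher-order updates and form the
        # scaled error E = ||x_higher - x_lower|| / delta. The step is accepted only when
        # E <= 1, and in either case the next step size becomes
        # h <- min(theta * h * E^(-1/order), lambda_0 - lambda_s).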
if order == 2:
r1 = 0.5
def lower_update(x, s, t): return self.dpm_solver_first_update(
x, s, t, return_intermediate=True)
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
solver_type=solver_type,
**kwargs)
elif order == 3:
r1, r2 = 1. / 3., 2. / 3.
def lower_update(x, s, t): return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
return_intermediate=True,
solver_type=solver_type)
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
solver_type=solver_type,
**kwargs)
else:
raise ValueError(
"For adaptive step size solver, order must be 2 or 3, got {}".format(order))
while torch.abs((s - t_0)).mean() > t_err:
t = ns.inverse_lambda(lambda_s + h)
x_lower, lower_noise_kwargs = lower_update(x, s, t)
x_higher = higher_update(x, s, t, **lower_noise_kwargs)
delta = torch.max(torch.ones_like(x).to(
x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
def norm_fn(v): return torch.sqrt(torch.square(
v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
E = norm_fn((x_higher - x_lower) / delta).max()
if torch.all(E <= 1.):
x = x_higher
s = t
x_prev = x_lower
lambda_s = ns.marginal_lambda(s)
h = torch.min(theta * h * torch.float_power(E, -
1. / order).float(), lambda_0 - lambda_s)
nfe += order
print('adaptive solver nfe', nfe)
return x
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
atol=0.0078, rtol=0.05,
):
"""
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
=====================================================
We support the following algorithms for both noise prediction model and data prediction model:
- 'singlestep':
Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
The total number of function evaluations (NFE) == `steps`.
Given a fixed NFE == `steps`, the sampling procedure is:
- If `order` == 1:
- Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
- If `order` == 2:
- Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
- If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
- If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
- If `order` == 3:
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
- If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
- If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
- If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
- 'multistep':
Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
We initialize the first `order` values by lower order multistep solvers.
Given a fixed NFE == `steps`, the sampling procedure is:
Denote K = steps.
- If `order` == 1:
- We use K steps of DPM-Solver-1 (i.e. DDIM).
- If `order` == 2:
- We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
- If `order` == 3:
- We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
- 'singlestep_fixed':
Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
- 'adaptive':
Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
                You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
                (NFE) and the sample quality.
- If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
- If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
=====================================================
        Some advice for choosing the algorithm:
- For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
e.g.
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
skip_type='time_uniform', method='singlestep')
- For **guided sampling with large guidance scale** by DPMs:
Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
e.g.
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
skip_type='time_uniform', method='multistep')
We support three types of `skip_type`:
- 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
- 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
- 'time_quadratic': quadratic time for the time steps.
=====================================================
Args:
x: A pytorch tensor. The initial value at time `t_start`
e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
            steps: An `int`. The total number of function evaluations (NFE).
t_start: A `float`. The starting time of the sampling.
                If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
t_end: A `float`. The ending time of the sampling.
If `t_end` is None, we use 1. / self.noise_schedule.total_N.
e.g. if total_N == 1000, we have `t_end` == 1e-3.
For discrete-time DPMs:
- We recommend `t_end` == 1. / self.noise_schedule.total_N.
For continuous-time DPMs:
- We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
            order: An `int`. The order of DPM-Solver.
skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
                This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
                score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID when
                sampling from diffusion models by diffusion SDEs for low-resolutional images
                (such as CIFAR-10). However, we observed that it does not matter for
                high-resolutional images. As it needs an additional NFE, we do not recommend
                it for high-resolutional images.
lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
Only valid for `method=multistep` and `steps < 15`. We empirically find that
this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
(especially for steps <= 10). So we recommend to set it to be `True`.
solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
Returns:
x_end: A pytorch tensor. The approximated solution at time `t_end`.
"""
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
t_T = self.noise_schedule.T if t_start is None else t_start
device = x.device
if method == 'adaptive':
with torch.no_grad():
x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
solver_type=solver_type)
elif method == 'multistep':
assert steps >= order
time_steps = self.get_time_steps(
skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
assert time_steps.shape[0] - 1 == steps
with torch.no_grad():
vec_t = time_steps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)]
t_prev_list = [vec_t]
# Init the first `order` values by lower order multistep DPM-Solver.
for init_order in tqdm(range(1, order), desc="DPM init order"):
vec_t = time_steps[init_order].expand(x.shape[0])
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
solver_type=solver_type)
model_prev_list.append(self.model_fn(x, vec_t))
t_prev_list.append(vec_t)
# Compute the remaining values by `order`-th order multistep DPM-Solver.
for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
vec_t = time_steps[step].expand(x.shape[0])
if lower_order_final and steps < 15:
step_order = min(order, steps + 1 - step)
else:
step_order = order
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
solver_type=solver_type)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = vec_t
# We do not need to evaluate the final model value.
if step < steps:
model_prev_list[-1] = self.model_fn(x, vec_t)
elif method in ['singlestep', 'singlestep_fixed']:
if method == 'singlestep':
time_steps_outer, orders = self.get_orders_and_time_steps_for_singlestep_solver(steps=steps, order=order,
skip_type=skip_type,
t_T=t_T, t_0=t_0,
device=device)
elif method == 'singlestep_fixed':
K = steps // order
orders = [order, ] * K
time_steps_outer = self.get_time_steps(
skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
for i, order in enumerate(orders):
t_T_inner, t_0_inner = time_steps_outer[i], time_steps_outer[i + 1]
time_steps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
N=order, device=device)
lambda_inner = self.noise_schedule.marginal_lambda(
time_steps_inner)
vec_s, vec_t = t_T_inner.tile(
x.shape[0]), t_0_inner.tile(x.shape[0])
h = lambda_inner[-1] - lambda_inner[0]
r1 = None if order <= 1 else (
lambda_inner[1] - lambda_inner[0]) / h
r2 = None if order <= 2 else (
lambda_inner[2] - lambda_inner[0]) / h
x = self.singlestep_dpm_solver_update(
x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
if denoise_to_zero:
x = self.denoise_to_zero_fn(
x, torch.ones((x.shape[0],)).to(device) * t_0)
return x
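# Illustrative end-to-end sketch (an assumption, not part of the original file): how
# NoiseScheduleVP, model_wrapper and DPM_Solver fit together. The constant-zero
# `_toy_noise_model` is a hypothetical stand-in for a trained network, so the "samples" it
# produces are meaningless; only the wiring is of interest.
def _example_dpm_solver_sampling():
    ns = NoiseScheduleVP('discrete', betas=torch.linspace(1e-4, 2e-2, 1000))
    def _toy_noise_model(x, t_input):
        return torch.zeros_like(x)
    model_fn = model_wrapper(_toy_noise_model, ns, model_type="noise", guidance_type="uncond")
    solver = DPM_Solver(model_fn, ns, predict_x0=True)
    x_T = torch.randn(2, 3, 8, 8)
    # 10 NFE with the multistep, order-2, predict_x0 configuration recommended above for
    # guided sampling.
    return solver.sample(x_T, steps=10, order=2, skip_type='time_uniform', method='multistep')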
def interpolate_fn(x, xp, yp):
"""
A piecewise linear function y = f(x), using xp and yp as keypoints.
We implement f(x) in a differentiable way (i.e. applicable for autograd).
    The function f(x) is well-defined for the whole x-axis. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
Args:
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
yp: PyTorch tensor with shape [C, K].
Returns:
The function values f(x), with shape [N, C].
"""
N, K = x.shape[0], xp.shape[1]
all_x = torch.cat(
[x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
x_idx = torch.argmin(x_indices, dim=2)
cand_start_idx = x_idx - 1
start_idx = torch.where(
torch.eq(x_idx, 0),
torch.tensor(1, device=x.device),
torch.where(
torch.eq(x_idx, K), torch.tensor(
K - 2, device=x.device), cand_start_idx,
),
)
end_idx = torch.where(torch.eq(start_idx, cand_start_idx),
start_idx + 2, start_idx + 1)
start_x = torch.gather(sorted_all_x, dim=2,
index=start_idx.unsqueeze(2)).squeeze(2)
end_x = torch.gather(sorted_all_x, dim=2,
index=end_idx.unsqueeze(2)).squeeze(2)
start_idx2 = torch.where(
torch.eq(x_idx, 0),
torch.tensor(0, device=x.device),
torch.where(
torch.eq(x_idx, K), torch.tensor(
K - 2, device=x.device), cand_start_idx,
),
)
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
start_y = torch.gather(y_positions_expanded, dim=2,
index=start_idx2.unsqueeze(2)).squeeze(2)
end_y = torch.gather(y_positions_expanded, dim=2, index=(
start_idx2 + 1).unsqueeze(2)).squeeze(2)
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
return cand
def expand_dims(v, dims):
"""
Expand the tensor `v` to the dim `dims`.
Args:
`v`: a PyTorch tensor with shape [N].
        `dims`: an `int`.
Returns:
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
"""
return v[(...,) + (None,) * (dims - 1)]
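# Illustrative sketch (an assumption): expand_dims appends singleton axes so that per-sample
# coefficients such as alpha_t (shape [N]) broadcast against image batches of shape [N, C, H, W].
def _example_expand_dims():
    v = torch.ones(4)                    # per-sample coefficients, shape [N]
    x = torch.randn(4, 3, 8, 8)
    return expand_dims(v, x.dim()) * x   # expand_dims(v, 4) has shape [4, 1, 1, 1]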
"""SAMPLING ONLY."""
import torch
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
MODEL_TYPES = {"eps": "noise", "v": "v"}
class DPMSolverSampler(object):
def __init__(self, model, **kwargs):
super().__init__()
self.model = model
def to_torch(x):
return x.clone().detach().to(torch.float32).to(model.device)
self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod))
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
x_T=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
        # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
try:
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                except Exception:
cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
if cbs != batch_size:
print(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size:
print(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
# sampling
T, C, H, W = shape
size = (batch_size, T, C, H, W)
print(f"Data shape for DPM-Solver sampling is {size}, sampling steps {S}")
device = self.model.betas.device
if x_T is None:
img = torch.randn(size, device=device)
else:
img = x_T
ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
model_fn = model_wrapper(
lambda x, t, c: self.model.apply_model(x, t, c),
ns,
model_type=MODEL_TYPES[self.model.parameterization],
guidance_type="classifier-free",
condition=conditioning,
unconditional_condition=unconditional_conditioning,
guidance_scale=unconditional_guidance_scale,
)
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
x = dpm_solver.sample(
img,
steps=S,
skip_type="time_uniform",
method="multistep",
order=2,
lower_order_final=True,
)
return x.to(device), None
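# Illustrative sketch (an assumption, not part of the original file): driving the sampler above
# with a latent-diffusion `model` that exposes `apply_model`, `betas`, `alphas_cumprod`,
# `parameterization` and `device`. The shapes, step count and guidance scale below are
# hypothetical placeholders, not recommended settings.
def _example_dpm_sampler_usage(model, cond, uncond):
    sampler = DPMSolverSampler(model)
    samples, _ = sampler.sample(
        S=25,                       # number of solver steps (NFE)
        batch_size=cond.shape[0],
        shape=(16, 4, 32, 32),      # (T, C, H, W) of the latent video
        conditioning=cond,
        x_T=None,
        unconditional_guidance_scale=7.5,
        unconditional_conditioning=uncond,
    )
    return samples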
"""SAMPLING ONLY."""
import numpy as np
from tqdm import tqdm
import torch
from core.models.utils_diffusion import (
make_ddim_sampling_parameters,
make_ddim_time_steps,
)
from core.common import noise_like
class PLMSSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_time_steps = model.num_time_steps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
):
if ddim_eta != 0:
raise ValueError("ddim_eta must be 0 for PLMS")
self.ddim_time_steps = make_ddim_time_steps(
ddim_discr_method=ddim_discretize,
num_ddim_time_steps=ddim_num_steps,
num_ddpm_time_steps=self.ddpm_num_time_steps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
alphas_cumprod.shape[0] == self.ddpm_num_time_steps
), "alphas have to be defined for each timestep"
def to_torch(x):
return x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer("betas", to_torch(self.model.betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer(
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_one_minus_alphas_cumprod",
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod",
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_time_steps=self.ddim_time_steps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer("ddim_sigmas", ddim_sigmas)
self.register_buffer("ddim_alphas", ddim_alphas)
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
        # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size:
print(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f"Data shape for PLMS sampling is {size}")
samples, intermediates = self.plms_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad()
def plms_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
time_steps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if time_steps is None:
time_steps = (
self.ddpm_num_time_steps
if ddim_use_original_steps
else self.ddim_time_steps
)
elif time_steps is not None and not ddim_use_original_steps:
subset_end = (
int(
min(time_steps / self.ddim_time_steps.shape[0], 1)
* self.ddim_time_steps.shape[0]
)
- 1
)
time_steps = self.ddim_time_steps[:subset_end]
intermediates = {"x_inter": [img], "pred_x0": [img]}
time_range = (
list(reversed(range(0, time_steps)))
if ddim_use_original_steps
else np.flip(time_steps)
)
total_steps = time_steps if ddim_use_original_steps else time_steps.shape[0]
print(f"Running PLMS Sampling with {total_steps} time_steps")
iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full(
(b,),
time_range[min(i + 1, len(time_range) - 1)],
device=device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts)
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_plms(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates["x_inter"].append(img)
intermediates["pred_x0"].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_plms(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if (
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
# 3rd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
# 4th order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (
55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t
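# --- illustrative sketch (added, not part of the original sampler) ---------
# The branches above blend the current noise estimate e_t with up to three
# stored estimates using pseudo linear multistep (Adams-Bashforth) weights;
# at every order the weights sum to 1. A minimal standalone helper, assuming
# plain tensors or floats, makes the coefficient pattern easy to check:
def _plms_combine(e_t, old_eps):
    """Return the blended eps estimate e_t' for the given history length."""
    if len(old_eps) == 0:
        # no history yet; the sampler instead takes a Heun-style 2nd-order step
        return e_t
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
# e.g. with constant inputs the 4th-order weights reduce to the identity:
# _plms_combine(1.0, [1.0, 1.0, 1.0]) == (55 - 59 + 37 - 9) / 24 == 1.0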
"""SAMPLING ONLY."""
import torch
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
class UniPCSampler(object):
def __init__(self, model, **kwargs):
super().__init__()
self.model = model
def to_torch(x):
return x.clone().detach().to(torch.float32).to(model.device)
self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod))
def register_buffer(self, name, attr):
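# note (added comment): buffers are moved to CUDA unconditionally below, so this
# sampler assumes a GPU is available; CPU-only use would require relaxing this.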
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
x_T=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
# sampling
T, C, H, W = shape
size = (batch_size, T, C, H, W)
device = self.model.betas.device
if x_T is None:
img = torch.randn(size, device=device)
else:
img = x_T
ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
model_fn = model_wrapper(
lambda x, t, c: self.model.apply_model(x, t, c),
ns,
model_type="v",
guidance_type="classifier-free",
condition=conditioning,
unconditional_condition=unconditional_conditioning,
guidance_scale=unconditional_guidance_scale,
)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False)
x = uni_pc.sample(
img,
steps=S,
skip_type="time_uniform",
method="multistep",
order=2,
lower_order_final=True,
)
return x.to(device), None
import torch
import torch.nn.functional as F
import math
class NoiseScheduleVP:
def __init__(
self,
schedule="discrete",
betas=None,
alphas_cumprod=None,
continuous_beta_0=0.1,
continuous_beta_1=20.0,
):
"""Create a wrapper class for the forward SDE (VP type).
***
Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
***
The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
log_alpha_t = self.marginal_log_mean_coeff(t)
sigma_t = self.marginal_std(t)
lambda_t = self.marginal_lambda(t)
Moreover, as lambda(t) is an invertible function, we also support its inverse function:
t = self.inverse_lambda(lambda_t)
===============================================================
We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
1. For discrete-time DPMs:
For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
t_i = (i + 1) / N
e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
Args:
betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
**Important**: Please pay special attention to the arg `alphas_cumprod`:
The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
alpha_{t_n} = \sqrt{\hat{alpha_n}},
and
log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
2. For continuous-time DPMs:
We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
schedule are the default settings in DDPM and improved-DDPM:
Args:
beta_min: A `float` number. The smallest beta for the linear schedule.
beta_max: A `float` number. The largest beta for the linear schedule.
cosine_s: A `float` number. The hyperparameter in the cosine schedule.
cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
T: A `float` number. The ending time of the forward process.
===============================================================
Args:
schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
'linear' or 'cosine' for continuous-time DPMs.
Returns:
A wrapper object of the forward SDE (VP type).
===============================================================
Example:
# For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
>>> ns = NoiseScheduleVP('discrete', betas=betas)
# For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
>>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
# For continuous-time DPMs (VPSDE), linear schedule:
>>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
"""
if schedule not in ["discrete", "linear", "cosine"]:
raise ValueError(
"Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
schedule
)
)
self.schedule = schedule
if schedule == "discrete":
if betas is not None:
log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
else:
assert alphas_cumprod is not None
log_alphas = 0.5 * torch.log(alphas_cumprod)
self.total_N = len(log_alphas)
self.T = 1.0
self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1))
self.log_alpha_array = log_alphas.reshape((1, -1))
else:
self.total_N = 1000
self.beta_0 = continuous_beta_0
self.beta_1 = continuous_beta_1
self.cosine_s = 0.008
self.cosine_beta_max = 999.0
self.cosine_t_max = (
math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)
* 2.0
* (1.0 + self.cosine_s)
/ math.pi
- self.cosine_s
)
self.cosine_log_alpha_0 = math.log(
math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)
)
self.schedule = schedule
if schedule == "cosine":
# For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
# Note that T = 0.9946 may not be the optimal setting. However, we find it works well.
self.T = 0.9946
else:
self.T = 1.0
def marginal_log_mean_coeff(self, t):
"""
Compute log(alpha_t) of a given continuous-time label t in [0, T].
"""
if self.schedule == "discrete":
return interpolate_fn(
t.reshape((-1, 1)),
self.t_array.to(t.device),
self.log_alpha_array.to(t.device),
).reshape((-1))
elif self.schedule == "linear":
return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
elif self.schedule == "cosine":
def log_alpha_fn(s):
return torch.log(
torch.cos(
(s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0
)
)
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
return log_alpha_t
def marginal_alpha(self, t):
"""
Compute alpha_t of a given continuous-time label t in [0, T].
"""
return torch.exp(self.marginal_log_mean_coeff(t))
def marginal_std(self, t):
"""
Compute sigma_t of a given continuous-time label t in [0, T].
"""
return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))
def marginal_lambda(self, t):
"""
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
"""
log_mean_coeff = self.marginal_log_mean_coeff(t)
log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
return log_mean_coeff - log_std
def inverse_lambda(self, lamb):
"""
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
"""
if self.schedule == "linear":
tmp = (
2.0
* (self.beta_1 - self.beta_0)
* torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
)
Delta = self.beta_0**2 + tmp
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
elif self.schedule == "discrete":
log_alpha = -0.5 * torch.logaddexp(
torch.zeros((1,)).to(lamb.device), -2.0 * lamb
)
t = interpolate_fn(
log_alpha.reshape((-1, 1)),
torch.flip(self.log_alpha_array.to(lamb.device), [1]),
torch.flip(self.t_array.to(lamb.device), [1]),
)
return t.reshape((-1,))
else:
log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
def t_fn(log_alpha_t):
return (
torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0))
* 2.0
* (1.0 + self.cosine_s)
/ math.pi
- self.cosine_s
)
t = t_fn(log_alpha)
return t
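# --- illustrative usage sketch (added, assuming the NoiseScheduleVP class above) ---
# Build a discrete-time schedule from a DDPM-style beta array and query
# alpha_t, sigma_t and lambda_t at a continuous time label; the betas here are
# placeholders. Call this only after the module is fully loaded, since the
# discrete schedule relies on interpolate_fn defined further below.
def _noise_schedule_demo():
    betas = torch.linspace(1e-4, 2e-2, 1000)            # linear DDPM betas (illustrative)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # \hat{alpha}_n = prod(1 - beta_i)
    ns = NoiseScheduleVP("discrete", alphas_cumprod=alphas_cumprod)
    t = torch.tensor([0.5])
    print(ns.marginal_alpha(t), ns.marginal_std(t), ns.marginal_lambda(t))
    # inverse_lambda recovers t up to interpolation error
    print(ns.inverse_lambda(ns.marginal_lambda(t)))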
def model_wrapper(
model,
noise_schedule,
model_type="noise",
model_kwargs={},
guidance_type="uncond",
condition=None,
unconditional_condition=None,
guidance_scale=1.0,
classifier_fn=None,
classifier_kwargs={},
):
"""Create a wrapper function for the noise prediction model.
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
first wrap the model function into a noise prediction model that accepts continuous time as input.
We support four types of the diffusion model by setting `model_type`:
1. "noise": noise prediction model. (Trained by predicting noise).
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
3. "v": velocity prediction model. (Trained by predicting the velocity).
The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
arXiv preprint arXiv:2202.00512 (2022).
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
arXiv preprint arXiv:2210.02303 (2022).
4. "score": marginal score function. (Trained by denoising score matching).
Note that the score function and the noise prediction model follow a simple relationship:
```
noise(x_t, t) = -sigma_t * score(x_t, t)
```
We support three types of guided sampling by DPMs by setting `guidance_type`:
1. "uncond": unconditional sampling by DPMs.
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
The input `classifier_fn` has the following format:
``
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
``
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
The input `model` has the following format:
``
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
``
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
arXiv preprint arXiv:2207.12598 (2022).
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
or continuous-time labels (i.e. epsilon to T).
We wrap the model function to accept only `x` and `t_continuous` as inputs, and to output the predicted noise:
``
def model_fn(x, t_continuous) -> noise:
t_input = get_model_input_time(t_continuous)
return noise_pred(model, x, t_input, **model_kwargs)
``
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
===============================================================
Args:
model: A diffusion model with the corresponding format described above.
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
model_type: A `str`. The parameterization type of the diffusion model.
"noise" or "x_start" or "v" or "score".
model_kwargs: A `dict`. A dict for the other inputs of the model function.
guidance_type: A `str`. The type of the guidance for sampling.
"uncond" or "classifier" or "classifier-free".
condition: A pytorch tensor. The condition for the guided sampling.
Only used for "classifier" or "classifier-free" guidance type.
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
Only used for "classifier-free" guidance type.
guidance_scale: A `float`. The scale for the guided sampling.
classifier_fn: A classifier function. Only used for the classifier guidance.
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
Returns:
A noise prediction model that accepts the noised data and the continuous time as the inputs.
"""
def get_model_input_time(t_continuous):
"""
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
For continuous-time DPMs, we just use `t_continuous`.
"""
if noise_schedule.schedule == "discrete":
return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
else:
return t_continuous
def noise_pred_fn(x, t_continuous, cond=None):
if t_continuous.reshape((-1,)).shape[0] == 1:
t_continuous = t_continuous.expand((x.shape[0]))
t_input = get_model_input_time(t_continuous)
if cond is None:
output = model(x, t_input, None, **model_kwargs)
else:
output = model(x, t_input, cond, **model_kwargs)
if isinstance(output, tuple):
output = output[0]
if model_type == "noise":
return output
elif model_type == "x_start":
alpha_t, sigma_t = noise_schedule.marginal_alpha(
t_continuous
), noise_schedule.marginal_std(t_continuous)
dims = x.dim()
return (x - expand_dims(alpha_t, dims) * output) / expand_dims(
sigma_t, dims
)
elif model_type == "v":
alpha_t, sigma_t = noise_schedule.marginal_alpha(
t_continuous
), noise_schedule.marginal_std(t_continuous)
dims = x.dim()
print("alpha_t.shape", alpha_t.shape)
print("sigma_t.shape", sigma_t.shape)
print("dims", dims)
print("x.shape", x.shape)
# x: b, t, c, h, w
alpha_t = alpha_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
sigma_t = sigma_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
print("alpha_t.shape", alpha_t.shape)
print("sigma_t.shape", sigma_t.shape)
print("output.shape", output.shape)
return alpha_t * output + sigma_t * x
elif model_type == "score":
sigma_t = noise_schedule.marginal_std(t_continuous)
dims = x.dim()
return -expand_dims(sigma_t, dims) * output
def cond_grad_fn(x, t_input):
"""
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
"""
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
return torch.autograd.grad(log_prob.sum(), x_in)[0]
def model_fn(x, t_continuous):
"""
The noise prediction model function that is used for DPM-Solver.
"""
if t_continuous.reshape((-1,)).shape[0] == 1:
t_continuous = t_continuous.expand((x.shape[0]))
if guidance_type == "uncond":
return noise_pred_fn(x, t_continuous)
elif guidance_type == "classifier":
assert classifier_fn is not None
t_input = get_model_input_time(t_continuous)
cond_grad = cond_grad_fn(x, t_input)
sigma_t = noise_schedule.marginal_std(t_continuous)
noise = noise_pred_fn(x, t_continuous)
return (
noise
- guidance_scale
* expand_dims(sigma_t, dims=cond_grad.dim())
* cond_grad
)
elif guidance_type == "classifier-free":
if guidance_scale == 1.0 or unconditional_condition is None:
return noise_pred_fn(x, t_continuous, cond=condition)
else:
x_in = x
t_in = t_continuous
print("x_in.shape=", x_in.shape)
print("t_in.shape=", t_in.shape)
noise = noise_pred_fn(x_in, t_in, cond=condition)
noise_uncond = noise_pred_fn(x_in, t_in, cond=unconditional_condition)
return noise_uncond + guidance_scale * (noise - noise_uncond)
assert model_type in ["noise", "x_start", "v"]
assert guidance_type in ["uncond", "classifier", "classifier-free"]
return model_fn
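# --- illustrative usage sketch (added; the toy model below is a placeholder,
# not the repository's video UNet) ---
# Wrap a dummy noise-prediction network with model_wrapper for classifier-free
# guidance and evaluate it at a continuous time label.
def _model_wrapper_demo():
    betas = torch.linspace(1e-4, 2e-2, 1000)
    ns = NoiseScheduleVP("discrete", alphas_cumprod=torch.cumprod(1.0 - betas, dim=0))

    def toy_model(x, t_input, cond, **kwargs):
        # stand-in noise predictor: ignores cond and returns zeros shaped like x
        return torch.zeros_like(x)

    model_fn = model_wrapper(
        toy_model,
        ns,
        model_type="noise",
        guidance_type="classifier-free",
        condition=torch.zeros(2, 8),
        unconditional_condition=torch.zeros(2, 8),
        guidance_scale=3.0,
    )
    x = torch.randn(2, 4, 16, 16)
    t = torch.tensor([0.5])
    print(model_fn(x, t).shape)  # torch.Size([2, 4, 16, 16])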
class UniPC:
def __init__(
self,
model_fn,
noise_schedule,
predict_x0=True,
thresholding=False,
max_val=1.0,
variant="bh1",
):
"""Construct a UniPC.
We support both data_prediction and noise_prediction.
"""
self.model = model_fn
self.noise_schedule = noise_schedule
self.variant = variant
self.predict_x0 = predict_x0
self.thresholding = thresholding
self.max_val = max_val
def dynamic_thresholding_fn(self, x0, t=None):
"""
The dynamic thresholding method.
"""
dims = x0.dim()
p = self.dynamic_thresholding_ratio
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(
torch.maximum(
s, self.thresholding_max_val * torch.ones_like(s).to(s.device)
),
dims,
)
x0 = torch.clamp(x0, -s, s) / s
return x0
def noise_prediction_fn(self, x, t):
"""
Return the noise prediction model.
"""
return self.model(x, t)
def data_prediction_fn(self, x, t):
"""
Return the data prediction model (with thresholding).
"""
noise = self.noise_prediction_fn(x, t)
dims = x.dim()
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(
t
), self.noise_schedule.marginal_std(t)
x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
if self.thresholding:
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(
torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims
)
x0 = torch.clamp(x0, -s, s) / s
return x0
def model_fn(self, x, t):
"""
Convert the model to the noise prediction model or the data prediction model.
"""
if self.predict_x0:
return self.data_prediction_fn(x, t)
else:
return self.noise_prediction_fn(x, t)
def get_time_steps(self, skip_type, t_T, t_0, N, device):
"""Compute the intermediate time steps for sampling."""
if skip_type == "logSNR":
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
logSNR_steps = torch.linspace(
lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1
).to(device)
return self.noise_schedule.inverse_lambda(logSNR_steps)
elif skip_type == "time_uniform":
return torch.linspace(t_T, t_0, N + 1).to(device)
elif skip_type == "time_quadratic":
t_order = 2
t = (
torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1)
.pow(t_order)
.to(device)
)
return t
else:
raise ValueError(
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(
skip_type
)
)
def get_orders_and_timesteps_for_singlestep_solver(
self, steps, order, skip_type, t_T, t_0, device
):
"""
Get the order of each step for sampling by the singlestep DPM-Solver.
"""
if order == 3:
K = steps // 3 + 1
if steps % 3 == 0:
orders = [3] * (K - 2) + [2, 1]
elif steps % 3 == 1:
orders = [3] * (K - 1) + [1]
else:
orders = [3] * (K - 1) + [2]
elif order == 2:
if steps % 2 == 0:
K = steps // 2
orders = [2] * K
else:
K = steps // 2 + 1
orders = [2] * (K - 1) + [1]
elif order == 1:
K = steps
orders = [1] * steps
else:
raise ValueError("'order' must be '1' or '2' or '3'.")
if skip_type == "logSNR":
# To reproduce the results in DPM-Solver paper
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
else:
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
torch.cumsum(torch.tensor([0] + orders), 0).to(device)
]
return timesteps_outer, orders
def denoise_to_zero_fn(self, x, s):
"""
Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity by a first-order discretization.
"""
return self.data_prediction_fn(x, s)
def multistep_uni_pc_update(
self, x, model_prev_list, t_prev_list, t, order, **kwargs
):
if len(t.shape) == 0:
t = t.view(-1)
if "bh" in self.variant:
return self.multistep_uni_pc_bh_update(
x, model_prev_list, t_prev_list, t, order, **kwargs
)
else:
assert self.variant == "vary_coeff"
return self.multistep_uni_pc_vary_update(
x, model_prev_list, t_prev_list, t, order, **kwargs
)
def multistep_uni_pc_vary_update(
self, x, model_prev_list, t_prev_list, t, order, use_corrector=True
):
print(
f"using unified predictor-corrector with order {order} (solver type: vary coeff)"
)
ns = self.noise_schedule
assert order <= len(model_prev_list)
# first compute rks
t_prev_0 = t_prev_list[-1]
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
lambda_t = ns.marginal_lambda(t)
model_prev_0 = model_prev_list[-1]
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
log_alpha_t = ns.marginal_log_mean_coeff(t)
alpha_t = torch.exp(log_alpha_t)
h = lambda_t - lambda_prev_0
rks = []
D1s = []
for i in range(1, order):
t_prev_i = t_prev_list[-(i + 1)]
model_prev_i = model_prev_list[-(i + 1)]
lambda_prev_i = ns.marginal_lambda(t_prev_i)
rk = (lambda_prev_i - lambda_prev_0) / h
rks.append(rk)
D1s.append((model_prev_i - model_prev_0) / rk)
rks.append(1.0)
rks = torch.tensor(rks, device=x.device)
K = len(rks)
# build C matrix
C = []
col = torch.ones_like(rks)
for k in range(1, K + 1):
C.append(col)
col = col * rks / (k + 1)
C = torch.stack(C, dim=1)
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1)  # (B, K, T, C, H, W)
C_inv_p = torch.linalg.inv(C[:-1, :-1])
A_p = C_inv_p
if use_corrector:
print("using corrector")
C_inv = torch.linalg.inv(C)
A_c = C_inv
hh = -h if self.predict_x0 else h
h_phi_1 = torch.expm1(hh)
h_phi_ks = []
factorial_k = 1
h_phi_k = h_phi_1
for k in range(1, K + 2):
h_phi_ks.append(h_phi_k)
h_phi_k = h_phi_k / hh - 1 / factorial_k
factorial_k *= k + 1
model_t = None
if self.predict_x0:
x_t_ = sigma_t / sigma_prev_0 * x - alpha_t * h_phi_1 * model_prev_0
# now predictor
x_t = x_t_
if len(D1s) > 0:
# compute the residuals for predictor
for k in range(K - 1):
x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum(
"bktchw,k->btchw", D1s, A_p[k]
)
# now corrector
if use_corrector:
model_t = self.model_fn(x_t, t)
D1_t = model_t - model_prev_0
x_t = x_t_
k = 0
for k in range(K - 1):
x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum(
"bktchw,k->btchw", D1s, A_c[k][:-1]
)
x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
else:
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
t_prev_0
), ns.marginal_log_mean_coeff(t)
x_t_ = (torch.exp(log_alpha_t - log_alpha_prev_0)) * x - (
sigma_t * h_phi_1
) * model_prev_0
# now predictor
x_t = x_t_
if len(D1s) > 0:
# compute the residuals for predictor
for k in range(K - 1):
x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum(
"bktchw,k->btchw", D1s, A_p[k]
)
# now corrector
if use_corrector:
model_t = self.model_fn(x_t, t)
D1_t = model_t - model_prev_0
x_t = x_t_
k = 0
for k in range(K - 1):
x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum(
"bktchw,k->btchw", D1s, A_c[k][:-1]
)
x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
return x_t, model_t
def multistep_uni_pc_bh_update(
self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True
):
print(
f"using unified predictor-corrector with order {order} (solver type: B(h))"
)
ns = self.noise_schedule
assert order <= len(model_prev_list)
dims = x.dim()
# first compute rks
t_prev_0 = t_prev_list[-1]
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
lambda_t = ns.marginal_lambda(t)
model_prev_0 = model_prev_list[-1]
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
t_prev_0
), ns.marginal_log_mean_coeff(t)
alpha_t = torch.exp(log_alpha_t)
h = lambda_t - lambda_prev_0
rks = []
D1s = []
for i in range(1, order):
t_prev_i = t_prev_list[-(i + 1)]
model_prev_i = model_prev_list[-(i + 1)]
lambda_prev_i = ns.marginal_lambda(t_prev_i)
rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
rks.append(rk)
D1s.append((model_prev_i - model_prev_0) / rk)
rks.append(1.0)
rks = torch.tensor(rks, device=x.device)
R = []
b = []
hh = -h[0] if self.predict_x0 else h[0]
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.variant == "bh1":
B_h = hh
elif self.variant == "bh2":
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=x.device)
# now predictor
use_predictor = len(D1s) > 0 and x_t is None
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1)  # (B, K, T, C, H, W)
if x_t is None:
# for order 2, we use a simplified version
if order == 2:
rhos_p = torch.tensor([0.5], device=b.device)
else:
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
else:
D1s = None
if use_corrector:
print("using corrector")
# for order 1, we use a simplified version
if order == 1:
rhos_c = torch.tensor([0.5], device=b.device)
else:
rhos_c = torch.linalg.solve(R, b)
model_t = None
if self.predict_x0:
x_t_ = (
expand_dims(sigma_t / sigma_prev_0, dims) * x
- expand_dims(alpha_t * h_phi_1, dims) * model_prev_0
)
if x_t is None:
if use_predictor:
pred_res = torch.einsum("k,bktchw->btchw", rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
if use_corrector:
model_t = self.model_fn(x_t, t)
if D1s is not None:
corr_res = torch.einsum("k,bktchw->btchw", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - model_prev_0
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (
corr_res + rhos_c[-1] * D1_t
)
else:
x_t_ = (
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
)
if x_t is None:
if use_predictor:
pred_res = torch.einsum("k,bktchw->btchw", rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
if use_corrector:
model_t = self.model_fn(x_t, t)
if D1s is not None:
corr_res = torch.einsum("k,bktchw->btchw", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - model_prev_0
x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (
corr_res + rhos_c[-1] * D1_t
)
return x_t, model_t
def sample(
self,
x,
steps=20,
t_start=None,
t_end=None,
order=3,
skip_type="time_uniform",
method="singlestep",
lower_order_final=True,
denoise_to_zero=False,
solver_type="dpm_solver",
atol=0.0078,
rtol=0.05,
corrector=False,
):
t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
t_T = self.noise_schedule.T if t_start is None else t_start
device = x.device
if method == "multistep":
assert steps >= order
timesteps = self.get_time_steps(
skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device
)
assert timesteps.shape[0] - 1 == steps
with torch.no_grad():
vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)]
t_prev_list = [vec_t]
# Init the first `order` values by lower order multistep DPM-Solver.
for init_order in range(1, order):
vec_t = timesteps[init_order].expand(x.shape[0])
x, model_x = self.multistep_uni_pc_update(
x,
model_prev_list,
t_prev_list,
vec_t,
init_order,
use_corrector=True,
)
if model_x is None:
model_x = self.model_fn(x, vec_t)
model_prev_list.append(model_x)
t_prev_list.append(vec_t)
for step in range(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0])
print(f"Current step={step}; vec_t={vec_t}.")
if lower_order_final:
step_order = min(order, steps + 1 - step)
else:
step_order = order
print("this step order:", step_order)
if step == steps:
print("do not run corrector at the last step")
use_corrector = False
else:
use_corrector = True
x, model_x = self.multistep_uni_pc_update(
x,
model_prev_list,
t_prev_list,
vec_t,
step_order,
use_corrector=use_corrector,
)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = vec_t
# We do not need to evaluate the final model value.
if step < steps:
if model_x is None:
model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x
else:
raise NotImplementedError()
if denoise_to_zero:
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
return x
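# --- illustrative end-to-end sketch (added; assumes the NoiseScheduleVP and
# UniPC classes above, plus expand_dims/interpolate_fn defined further below,
# so call it only after the module is fully loaded) ---
# Run a few multistep UniPC steps with a toy noise predictor on a 5-D latent
# of shape (B, T, C, H, W), the video layout the einsum terms above expect.
def _uni_pc_demo():
    betas = torch.linspace(1e-4, 2e-2, 1000)
    ns = NoiseScheduleVP("discrete", alphas_cumprod=torch.cumprod(1.0 - betas, dim=0))

    def toy_model_fn(x, t_continuous):
        # stand-in noise prediction: zeros with the same shape as x
        return torch.zeros_like(x)

    solver = UniPC(toy_model_fn, ns, predict_x0=True, thresholding=False)
    x_T = torch.randn(1, 2, 4, 8, 8)  # (B, T, C, H, W)
    x_0 = solver.sample(
        x_T,
        steps=5,
        skip_type="time_uniform",
        method="multistep",
        order=2,
        lower_order_final=True,
    )
    print(x_0.shape)  # torch.Size([1, 2, 4, 8, 8])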
#############################################################
# other utility functions
#############################################################
def interpolate_fn(x, xp, yp):
"""
A piecewise linear function y = f(x), using xp and yp as keypoints.
We implement f(x) in a differentiable way (i.e. applicable for autograd).
The function f(x) is well-defined over the entire x-axis. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
Args:
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
yp: PyTorch tensor with shape [C, K].
Returns:
The function values f(x), with shape [N, C].
"""
N, K = x.shape[0], xp.shape[1]
all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
x_idx = torch.argmin(x_indices, dim=2)
cand_start_idx = x_idx - 1
start_idx = torch.where(
torch.eq(x_idx, 0),
torch.tensor(1, device=x.device),
torch.where(
torch.eq(x_idx, K),
torch.tensor(K - 2, device=x.device),
cand_start_idx,
),
)
end_idx = torch.where(
torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1
)
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
start_idx2 = torch.where(
torch.eq(x_idx, 0),
torch.tensor(0, device=x.device),
torch.where(
torch.eq(x_idx, K),
torch.tensor(K - 2, device=x.device),
cand_start_idx,
),
)
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
start_y = torch.gather(
y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)
).squeeze(2)
end_y = torch.gather(
y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)
).squeeze(2)
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
return cand
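# --- illustrative sketch (added) ---
# With keypoints xp = [0, 1, 2] and yp = [0, 10, 20] the interpolant is the
# line y = 10 * x, so the queries below return 2.5 and 15.0.
def _interpolate_fn_demo():
    xp = torch.tensor([[0.0, 1.0, 2.0]])    # [C=1, K=3]
    yp = torch.tensor([[0.0, 10.0, 20.0]])  # [C=1, K=3]
    x = torch.tensor([[0.25], [1.5]])       # [N=2, C=1]
    print(interpolate_fn(x, xp, yp))        # tensor([[ 2.5000], [15.0000]])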
def expand_dims(v, dims):
"""
Expand the tensor `v` to the dim `dims`.
Args:
`v`: a PyTorch tensor with shape [N].
`dims`: an `int`.
Returns:
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
"""
return v[(...,) + (None,) * (dims - 1)]
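# --- illustrative sketch (added) ---
# expand_dims reshapes a length-N vector so it broadcasts against an N-batch
# tensor with `dims` total dimensions.
def _expand_dims_demo():
    v = torch.tensor([1.0, 2.0])                 # shape [N=2]
    x = torch.randn(2, 3, 4, 5)                  # a 4-D batch tensor
    print(expand_dims(v, x.dim()).shape)         # torch.Size([2, 1, 1, 1])
    print((expand_dims(v, x.dim()) * x).shape)   # broadcasts to torch.Size([2, 3, 4, 5])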