"server/text_generation_server/layers/moe/unquantized.py" did not exist on "86984e3236ef4771d964f6ac36b226717845f561"
Commit 30af93f2 authored by chenpangpang's avatar chenpangpang
Browse files

feat: gpu初始提交

parent 68e98ab8
Pipeline #2159 canceled with stages
import copy
import numpy as np
import torch
from scipy.spatial.transform import Rotation as R
def get_opencv_from_blender(matrix_world, fov, image_size):
# convert matrix_world to opencv format extrinsics
opencv_world_to_cam = matrix_world.inverse()
opencv_world_to_cam[1, :] *= -1
opencv_world_to_cam[2, :] *= -1
R, T = opencv_world_to_cam[:3, :3], opencv_world_to_cam[:3, 3]
R, T = R.unsqueeze(0), T.unsqueeze(0)
# convert fov to opencv format intrinsics
focal = 1 / np.tan(fov / 2)
intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
opencv_cam_matrix = torch.from_numpy(intrinsics).unsqueeze(0).float()
opencv_cam_matrix[:, :2, -1] += torch.tensor([image_size / 2, image_size / 2])
opencv_cam_matrix[:, [0, 1], [0, 1]] *= image_size / 2
return R, T, opencv_cam_matrix
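# --- Usage sketch (not part of the original module) --------------------------
# A minimal example of converting a Blender camera to OpenCV-style extrinsics
# and intrinsics. The identity matrix_world and 60-degree FOV are illustrative
# values, not taken from any dataset.
def _demo_get_opencv_from_blender():
    matrix_world = torch.eye(4)      # camera at origin, looking down -Z (Blender convention)
    fov = np.deg2rad(60.0)           # FOV in radians
    image_size = 256
    R_cv, T_cv, K = get_opencv_from_blender(matrix_world, fov, image_size)
    # R_cv: (1, 3, 3) rotation, T_cv: (1, 3) translation, K: (1, 3, 3) intrinsics
    return R_cv, T_cv, K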
def cartesian_to_spherical(xyz):
xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
z = np.sqrt(xy + xyz[:, 2] ** 2)
# for elevation angle defined from z-axis down
theta = np.arctan2(np.sqrt(xy), xyz[:, 2])
azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
return np.stack([theta, azimuth, z], axis=-1)
def spherical_to_cartesian(spherical_coords):
# convert from spherical to cartesian coordinates
theta, azimuth, radius = spherical_coords.T
x = radius * np.sin(theta) * np.cos(azimuth)
y = radius * np.sin(theta) * np.sin(azimuth)
z = radius * np.cos(theta)
return np.stack([x, y, z], axis=-1)
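# --- Usage sketch (not part of the original module) --------------------------
# Round-trip check between the two coordinate conversions above. The sample
# point is arbitrary; theta is the polar angle measured from the +z axis, as
# in cartesian_to_spherical.
def _demo_spherical_round_trip():
    xyz = np.array([[1.0, 2.0, 2.0]])
    spherical = cartesian_to_spherical(xyz)        # [theta, azimuth, radius]
    xyz_back = spherical_to_cartesian(spherical)
    assert np.allclose(xyz, xyz_back, atol=1e-6)
    return spherical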
def look_at(eye, center, up):
# Create a normalized direction vector from eye to center
f = np.array(center) - np.array(eye)
f /= np.linalg.norm(f)
# Create a normalized right vector
up_norm = np.array(up) / np.linalg.norm(up)
s = np.cross(f, up_norm)
s /= np.linalg.norm(s)
# Recompute the up vector
u = np.cross(s, f)
# Create rotation matrix R
R = np.array([[s[0], s[1], s[2]], [u[0], u[1], u[2]], [-f[0], -f[1], -f[2]]])
# Create translation vector T
T = -np.dot(R, np.array(eye))
return R, T
def get_blender_from_spherical(elevation, azimuth):
"""Generates blender camera from spherical coordinates."""
cartesian_coords = spherical_to_cartesian(np.array([[elevation, azimuth, 3.5]]))
# get camera rotation
center = np.array([0, 0, 0])
eye = cartesian_coords[0]
up = np.array([0, 0, 1])
R, T = look_at(eye, center, up)
R = R.T
T = -np.dot(R, T)
RT = np.concatenate([R, T.reshape(3, 1)], axis=-1)
blender_cam = torch.from_numpy(RT).float()
blender_cam = torch.cat([blender_cam, torch.tensor([[0, 0, 0, 1]])], dim=0)
print(blender_cam)
return blender_cam
def invert_pose(r, t):
r_inv = r.T
t_inv = -np.dot(r_inv, t)
return r_inv, t_inv
def transform_pose_sequence_to_relative(poses, as_z_up=False):
"""
poses: a sequence of 3*4 C2W camera pose matrices
as_z_up: output in z-up format. If False, the output is in y-up format
"""
r0, t0 = poses[0][:3, :3], poses[0][:3, 3]
# r0_inv, t0_inv = invert_pose(r0, t0)
r0_inv = r0.T
new_rt0 = np.hstack([np.eye(3, 3), np.zeros((3, 1))])
if as_z_up:
new_rt0 = c2w_y_up_to_z_up(new_rt0)
transformed_poses = [new_rt0]
for pose in poses[1:]:
r, t = pose[:3, :3], pose[:3, 3]
new_r = np.dot(r0_inv, r)
new_t = np.dot(r0_inv, t - t0)
new_rt = np.hstack([new_r, new_t[:, None]])
if as_z_up:
new_rt = c2w_y_up_to_z_up(new_rt)
transformed_poses.append(new_rt)
return transformed_poses
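# --- Usage sketch (not part of the original module) --------------------------
# Makes a two-pose C2W sequence relative to its first frame: the first output
# becomes identity rotation / zero translation and the second pose is expressed
# in the first camera's frame. The poses below are illustrative.
def _demo_transform_pose_sequence_to_relative():
    pose0 = np.hstack([np.eye(3), np.array([[0.0], [0.0], [1.0]])])   # 3x4 C2W
    rot_z_90 = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    pose1 = np.hstack([rot_z_90, np.array([[1.0], [0.0], [1.0]])])    # 3x4 C2W
    relative = transform_pose_sequence_to_relative([pose0, pose1])
    assert np.allclose(relative[0], np.hstack([np.eye(3), np.zeros((3, 1))]))
    return relative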
def c2w_y_up_to_z_up(c2w_3x4):
R_y_up_to_z_up = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
R = c2w_3x4[:, :3]
t = c2w_3x4[:, 3]
R_z_up = R_y_up_to_z_up @ R
t_z_up = R_y_up_to_z_up @ t
T_z_up = np.hstack((R_z_up, t_z_up.reshape(3, 1)))
return T_z_up
def transform_pose_sequence_to_relative_w2c(poses):
new_rt_list = []
first_frame_rt = copy.deepcopy(poses[0])
first_frame_r_inv = first_frame_rt[:, :3].T
first_frame_t = first_frame_rt[:, -1]
for rt in poses:
rt[:, :3] = np.matmul(rt[:, :3], first_frame_r_inv)
rt[:, -1] = rt[:, -1] - np.matmul(rt[:, :3], first_frame_t)
new_rt_list.append(copy.deepcopy(rt))
return new_rt_list
def transform_pose_sequence_to_relative_c2w(poses):
first_frame_rt = poses[0]
first_frame_r_inv = first_frame_rt[:, :3].T
first_frame_t = first_frame_rt[:, -1]
rotations = poses[:, :, :3]
translations = poses[:, :, 3]
# Compute new rotations and translations in batch
new_rotations = torch.matmul(first_frame_r_inv, rotations)
new_translations = torch.matmul(
first_frame_r_inv, (translations - first_frame_t.unsqueeze(0)).unsqueeze(-1)
)
# Concatenate new rotations and translations
new_rt = torch.cat([new_rotations, new_translations], dim=-1)
return new_rt
def convert_w2c_between_c2w(poses):
rotations = poses[:, :, :3]
translations = poses[:, :, 3]
new_rotations = rotations.transpose(-1, -2)
new_translations = torch.matmul(-new_rotations, translations.unsqueeze(-1))
new_rt = torch.cat([new_rotations, new_translations], dim=-1)
return new_rt
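# --- Usage sketch (not part of the original module) --------------------------
# convert_w2c_between_c2w applies [R | t] -> [R^T | -R^T t], which is its own
# inverse, so converting twice recovers the original batch of poses. The test
# pose (90-degree rotation about z plus a translation) is illustrative.
def _demo_convert_w2c_between_c2w():
    rot_z_90 = torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    t = torch.tensor([[1.0], [2.0], [3.0]])
    poses = torch.cat([rot_z_90, t], dim=-1).unsqueeze(0)   # (1, 3, 4)
    twice = convert_w2c_between_c2w(convert_w2c_between_c2w(poses))
    assert torch.allclose(poses, twice, atol=1e-6)
    return twice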
def slerp(q1, q2, t):
"""
Performs spherical linear interpolation (SLERP) between two quaternions.
Args:
q1 (torch.Tensor): Start quaternion (4,).
q2 (torch.Tensor): End quaternion (4,).
t (float or torch.Tensor): Interpolation parameter in [0, 1].
Returns:
torch.Tensor: Interpolated quaternion (4,).
"""
q1 = q1 / torch.linalg.norm(q1) # Normalize q1
q2 = q2 / torch.linalg.norm(q2) # Normalize q2
dot = torch.dot(q1, q2)
# Ensure shortest path (flip q2 if needed)
if dot < 0.0:
q2 = -q2
dot = -dot
# Avoid numerical precision issues
dot = torch.clamp(dot, -1.0, 1.0)
theta = torch.acos(dot) # Angle between q1 and q2
if theta < 1e-6: # If very close, use linear interpolation
return (1 - t) * q1 + t * q2
sin_theta = torch.sin(theta)
return (torch.sin((1 - t) * theta) / sin_theta) * q1 + (
torch.sin(t * theta) / sin_theta
) * q2
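# --- Usage sketch (not part of the original module) --------------------------
# Interpolates halfway between the identity rotation and a 90-degree rotation
# about z; the result should correspond to a 45-degree rotation. Quaternions
# are in (x, y, z, w) order, matching the scipy convention used below.
def _demo_slerp():
    q_identity = torch.tensor([0.0, 0.0, 0.0, 1.0], dtype=torch.float64)
    q_z90 = torch.tensor(R.from_euler("z", 90, degrees=True).as_quat())
    q_half = slerp(q_identity, q_z90, 0.5)
    q_expected = torch.tensor(R.from_euler("z", 45, degrees=True).as_quat())
    assert torch.allclose(q_half, q_expected, atol=1e-6)
    return q_half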
def interpolate_camera_poses(c2w: torch.Tensor, factor: int) -> torch.Tensor:
"""
    Interpolates a sequence of camera c2w poses to `factor` times the length of the original sequence.
Args:
c2w (torch.Tensor): Input camera poses of shape (N, 3, 4).
factor (int): The upsampling factor (e.g., 2 for doubling the length).
Returns:
torch.Tensor: Interpolated camera poses of shape (N * factor, 3, 4).
"""
assert c2w.ndim == 3 and c2w.shape[1:] == (
3,
4,
), "Input tensor must have shape (N, 3, 4)."
assert factor > 1, "Upsampling factor must be greater than 1."
N = c2w.shape[0]
new_length = N * factor
# Extract rotations (R) and translations (T)
rotations = c2w[:, :3, :3] # Shape (N, 3, 3)
translations = c2w[:, :3, 3] # Shape (N, 3)
# Convert rotations to quaternions for interpolation
quaternions = torch.tensor(
R.from_matrix(rotations.numpy()).as_quat()
) # Shape (N, 4)
# Initialize interpolated quaternions and translations
interpolated_quats = []
interpolated_translations = []
# Perform interpolation
for i in range(N - 1):
# Start and end quaternions and translations for this segment
q1, q2 = quaternions[i], quaternions[i + 1]
t1, t2 = translations[i], translations[i + 1]
        # Time steps for interpolation within this segment. The segment end
        # point is excluded so consecutive segments do not produce duplicate
        # poses; the final pose is appended after the loop.
        t_values = torch.arange(factor, dtype=torch.float32) / factor
        # Interpolate quaternions using SLERP
        for t in t_values:
            interpolated_quats.append(slerp(q1, q2, t))
        # Interpolate translations linearly
        interp_t = t1 * (1 - t_values[:, None]) + t2 * t_values[:, None]
        interpolated_translations.append(interp_t)
    # Pad with the last pose so the output length is exactly N * factor
    for _ in range(factor):
        interpolated_quats.append(quaternions[-1])
        interpolated_translations.append(translations[-1].unsqueeze(0))  # as 2D tensor
# Combine interpolated results
interpolated_quats = torch.stack(interpolated_quats, dim=0) # Shape (new_length, 4)
interpolated_translations = torch.cat(
interpolated_translations, dim=0
) # Shape (new_length, 3)
# Convert quaternions back to rotation matrices
interpolated_rotations = torch.tensor(
R.from_quat(interpolated_quats.numpy()).as_matrix()
) # Shape (new_length, 3, 3)
# Form final c2w matrix
interpolated_c2w = torch.zeros((new_length, 3, 4), dtype=torch.float32)
interpolated_c2w[:, :3, :3] = interpolated_rotations
interpolated_c2w[:, :3, 3] = interpolated_translations
return interpolated_c2w
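# --- Usage sketch (not part of the original module) --------------------------
# Doubles the temporal resolution of a two-pose trajectory. The start/end poses
# (identity and a 90-degree yaw with a small translation) are illustrative; the
# output contains N * factor = 4 poses.
def _demo_interpolate_camera_poses():
    rot_z_90 = torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    start = torch.cat([torch.eye(3), torch.zeros(3, 1)], dim=-1)
    end = torch.cat([rot_z_90, torch.tensor([[0.5], [0.0], [0.0]])], dim=-1)
    c2w = torch.stack([start, end], dim=0)            # (2, 3, 4), float32
    interpolated = interpolate_camera_poses(c2w, factor=2)
    assert interpolated.shape == (4, 3, 4)
    return interpolated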
import PIL
import numpy as np
import torch
from PIL import Image
from .camera_pose_utils import (
convert_w2c_between_c2w,
transform_pose_sequence_to_relative_c2w,
)
def get_ray_embeddings(
poses, size_h=256, size_w=256, fov_xy_list=None, focal_xy_list=None
):
"""
    poses: sequence of camera poses (y-up format)
"""
use_focal = False
if fov_xy_list is None or fov_xy_list[0] is None or fov_xy_list[0][0] is None:
assert focal_xy_list is not None
use_focal = True
rays_embeddings = []
for i in range(poses.shape[0]):
cur_pose = poses[i]
if use_focal:
rays_o, rays_d = get_rays(
# [h, w, 3]
cur_pose,
size_h,
size_w,
focal_xy=focal_xy_list[i],
)
else:
rays_o, rays_d = get_rays(
cur_pose, size_h, size_w, fov_xy=fov_xy_list[i]
) # [h, w, 3]
rays_plucker = torch.cat(
[torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1
) # [h, w, 6]
rays_embeddings.append(rays_plucker)
rays_embeddings = (
torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous()
) # [V, 6, h, w]
return rays_embeddings
def get_rays(pose, h, w, fov_xy=None, focal_xy=None, opengl=True):
x, y = torch.meshgrid(
torch.arange(w, device=pose.device),
torch.arange(h, device=pose.device),
indexing="xy",
)
x = x.flatten()
y = y.flatten()
cx = w * 0.5
cy = h * 0.5
# print("fov_xy=", fov_xy)
# print("focal_xy=", focal_xy)
if focal_xy is None:
assert fov_xy is not None, "fov_x/y and focal_x/y cannot both be None."
focal_x = w * 0.5 / np.tan(0.5 * np.deg2rad(fov_xy[0]))
focal_y = h * 0.5 / np.tan(0.5 * np.deg2rad(fov_xy[1]))
else:
assert (
len(focal_xy) == 2
), "focal_xy should be a list-like object containing only two elements (focal length in x and y direction)."
focal_x = w * focal_xy[0]
focal_y = h * focal_xy[1]
camera_dirs = torch.nn.functional.pad(
torch.stack(
[
(x - cx + 0.5) / focal_x,
(y - cy + 0.5) / focal_y * (-1.0 if opengl else 1.0),
],
dim=-1,
),
(0, 1),
value=(-1.0 if opengl else 1.0),
) # [hw, 3]
rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3]
rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]
rays_o = rays_o.view(h, w, 3)
rays_d = safe_normalize(rays_d).view(h, w, 3)
return rays_o, rays_d
def safe_normalize(x, eps=1e-20):
return x / length(x, eps)
def length(x, eps=1e-20):
if isinstance(x, np.ndarray):
return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))
else:
return torch.sqrt(torch.clamp(dot(x, x), min=eps))
def dot(x, y):
if isinstance(x, np.ndarray):
return np.sum(x * y, -1, keepdims=True)
else:
return torch.sum(x * y, -1, keepdim=True)
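# --- Usage sketch (not part of the original module) --------------------------
# Builds per-pixel ray origins/directions for an identity camera pose and packs
# them into the 6-channel Plücker embedding used by get_ray_embeddings. The
# resolution and FOV values are illustrative.
def _demo_get_rays_plucker():
    pose = torch.eye(4)[:3]                      # (3, 4) C2W, camera at origin
    rays_o, rays_d = get_rays(pose, h=32, w=32, fov_xy=(60.0, 60.0))
    plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1)
    assert plucker.shape == (32, 32, 6)
    return plucker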
def extend_list_by_repeating(original_list, target_length, repeat_idx, at_front):
if not original_list:
raise ValueError("The original list cannot be empty.")
extended_list = []
original_length = len(original_list)
for i in range(target_length - original_length):
extended_list.append(original_list[repeat_idx])
if at_front:
extended_list.extend(original_list)
return extended_list
else:
original_list.extend(extended_list)
return original_list
def select_evenly_spaced_elements(arr, x):
if x <= 0 or len(arr) == 0:
return []
# Calculate step size as the ratio of length of the list and x
step = len(arr) / x
# Pick elements at indices that are multiples of step (round them to nearest integer)
selected_elements = [arr[round(i * step)] for i in range(x)]
return selected_elements
def convert_co3d_annotation_to_opengl_pose_and_intrinsics(frame_annotation):
p = frame_annotation.viewpoint.principal_point
f = frame_annotation.viewpoint.focal_length
h, w = frame_annotation.image.size
K = np.eye(3)
s = (min(h, w) - 1) / 2
if frame_annotation.viewpoint.intrinsics_format == "ndc_norm_image_bounds":
K[0, 0] = f[0] * (w - 1) / 2
K[1, 1] = f[1] * (h - 1) / 2
elif frame_annotation.viewpoint.intrinsics_format == "ndc_isotropic":
K[0, 0] = f[0] * s / 2
K[1, 1] = f[1] * s / 2
else:
assert (
False
), f"Invalid intrinsics_format: {frame_annotation.viewpoint.intrinsics_format}"
K[0, 2] = -p[0] * s + (w - 1) / 2
K[1, 2] = -p[1] * s + (h - 1) / 2
R = np.array(frame_annotation.viewpoint.R).T # note the transpose here
T = np.array(frame_annotation.viewpoint.T)
pose = np.concatenate([R, T[:, None]], 1)
# Need to be converted into OpenGL format. Flip the direction of x, z axis
pose = np.diag([-1, 1, -1]).astype(np.float32) @ pose
return pose, K
def normalize_w2c_camera_pose_sequence(
target_camera_poses,
condition_camera_poses=None,
output_c2w=False,
translation_norm_mode="div_by_max",
):
"""
Normalize camera pose sequence so that the first frame is identity rotation and zero translation,
    and the translation scale is normalized by the farthest point from the first frame (to one).
:param target_camera_poses: W2C poses tensor in [N, 3, 4]
:param condition_camera_poses: W2C poses tensor in [N, 3, 4]
:return: Tuple(Tensor, Tensor), the normalized `target_camera_poses` and `condition_camera_poses`
"""
# Normalize at w2c, all poses should be in w2c in UnifiedFrame
num_target_views = target_camera_poses.size(0)
if condition_camera_poses is not None:
all_poses = torch.concat([target_camera_poses, condition_camera_poses], dim=0)
else:
all_poses = target_camera_poses
# Convert W2C to C2W
normalized_poses = transform_pose_sequence_to_relative_c2w(
convert_w2c_between_c2w(all_poses)
)
# Here normalized_poses is C2W
if not output_c2w:
# Convert from C2W back to W2C if output_c2w is False.
normalized_poses = convert_w2c_between_c2w(normalized_poses)
t_norms = torch.linalg.norm(normalized_poses[:, :, 3], ord=2, dim=-1)
# print("t_norms=", t_norms)
largest_t_norm = torch.max(t_norms)
# print("largest_t_norm=", largest_t_norm)
# normalized_poses[:, :, 3] -= first_t.unsqueeze(0).repeat(normalized_poses.size(0), 1)
if translation_norm_mode == "div_by_max_plus_one":
# Always add a constant component to the translation norm
largest_t_norm = largest_t_norm + 1.0
elif translation_norm_mode == "div_by_max":
largest_t_norm = largest_t_norm
if largest_t_norm <= 0.05:
largest_t_norm = 0.05
elif translation_norm_mode == "disabled":
largest_t_norm = 1.0
else:
assert False, f"Invalid translation_norm_mode: {translation_norm_mode}."
normalized_poses[:, :, 3] /= largest_t_norm
target_camera_poses = normalized_poses[:num_target_views]
if condition_camera_poses is not None:
condition_camera_poses = normalized_poses[num_target_views:]
else:
condition_camera_poses = None
# print("After First condition:", condition_camera_poses[0])
# print("After First target:", target_camera_poses[0])
return target_camera_poses, condition_camera_poses
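# --- Usage sketch (not part of the original module) --------------------------
# Normalizes a short W2C trajectory: after the call the first target pose is
# the identity, and with "div_by_max" the largest camera-center distance from
# the first frame is scaled to one. The input poses are illustrative.
def _demo_normalize_w2c_camera_pose_sequence():
    rot_z_90 = torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    pose0 = torch.cat([torch.eye(3), torch.tensor([[0.0], [0.0], [2.0]])], dim=-1)
    pose1 = torch.cat([rot_z_90, torch.tensor([[1.0], [0.0], [2.0]])], dim=-1)
    target = torch.stack([pose0, pose1], dim=0)          # (2, 3, 4) W2C
    normalized, _ = normalize_w2c_camera_pose_sequence(
        target, translation_norm_mode="div_by_max"
    )
    assert normalized.shape == (2, 3, 4)
    return normalized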
def central_crop_pil_image(_image, crop_size, use_central_padding=False):
if use_central_padding:
# Determine the new size
_w, _h = _image.size
new_size = max(_w, _h)
# Create a new image with white background
new_image = Image.new("RGB", (new_size, new_size), (255, 255, 255))
# Calculate the position to paste the original image
paste_position = ((new_size - _w) // 2, (new_size - _h) // 2)
# Paste the original image onto the new image
new_image.paste(_image, paste_position)
_image = new_image
# get the new size again if padded
_w, _h = _image.size
scale = crop_size / min(_h, _w)
# resize shortest side to crop_size
_w_out, _h_out = int(scale * _w), int(scale * _h)
_image = _image.resize(
(_w_out, _h_out),
resample=(
PIL.Image.Resampling.LANCZOS if scale < 1 else PIL.Image.Resampling.BICUBIC
),
)
# center crop
margin_w = (_image.size[0] - crop_size) // 2
margin_h = (_image.size[1] - crop_size) // 2
_image = _image.crop(
(margin_w, margin_h, margin_w + crop_size, margin_h + crop_size)
)
return _image
def crop_and_resize(
image: Image.Image, target_width: int, target_height: int
) -> Image.Image:
"""
Crops and resizes an image while preserving the aspect ratio.
Args:
image (Image.Image): Input PIL image to be cropped and resized.
target_width (int): Target width of the output image.
target_height (int): Target height of the output image.
Returns:
Image.Image: Cropped and resized image.
"""
# Original dimensions
original_width, original_height = image.size
original_aspect = original_width / original_height
target_aspect = target_width / target_height
# Calculate crop box to maintain aspect ratio
if original_aspect > target_aspect:
# Crop horizontally
new_width = int(original_height * target_aspect)
new_height = original_height
left = (original_width - new_width) / 2
top = 0
right = left + new_width
bottom = original_height
else:
# Crop vertically
new_width = original_width
new_height = int(original_width / target_aspect)
left = 0
top = (original_height - new_height) / 2
right = original_width
bottom = top + new_height
# Crop and resize
cropped_image = image.crop((left, top, right, bottom))
resized_image = cropped_image.resize((target_width, target_height), Image.LANCZOS)
return resized_image
def calculate_fov_after_resize(
fov_x: float,
fov_y: float,
original_width: int,
original_height: int,
target_width: int,
target_height: int,
) -> (float, float):
"""
Calculates the new field of view after cropping and resizing an image.
Args:
fov_x (float): Original field of view in the x-direction (horizontal).
fov_y (float): Original field of view in the y-direction (vertical).
original_width (int): Original width of the image.
original_height (int): Original height of the image.
target_width (int): Target width of the output image.
target_height (int): Target height of the output image.
Returns:
(float, float): New field of view (fov_x, fov_y) after cropping and resizing.
"""
original_aspect = original_width / original_height
target_aspect = target_width / target_height
if original_aspect > target_aspect:
# Crop horizontally
new_width = int(original_height * target_aspect)
new_fov_x = fov_x * (new_width / original_width)
new_fov_y = fov_y
else:
# Crop vertically
new_height = int(original_width / target_aspect)
new_fov_y = fov_y * (new_height / original_height)
new_fov_x = fov_x
return new_fov_x, new_fov_y
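# --- Usage sketch (not part of the original module) --------------------------
# Center-crops a 640x480 image to a 256x256 output and computes the field of
# view implied by the same crop (using the module's linear approximation). The
# image content and FOV numbers are illustrative.
def _demo_crop_resize_with_fov():
    image = Image.new("RGB", (640, 480), (128, 128, 128))
    resized = crop_and_resize(image, target_width=256, target_height=256)
    new_fov_x, new_fov_y = calculate_fov_after_resize(
        fov_x=60.0, fov_y=47.0, original_width=640, original_height=480,
        target_width=256, target_height=256,
    )
    assert resized.size == (256, 256)
    return resized, (new_fov_x, new_fov_y)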
import copy
import random
from PIL import Image
import numpy as np
def create_relative(RT_list, K_1=4.7, dataset="syn"):
if dataset == "realestate":
scale_T = 1
RT_list = [RT.reshape(3, 4) for RT in RT_list]
elif dataset == "syn":
scale_T = (470 / K_1) / 7.5
"""
4.694746736956946052e+02 0.000000000000000000e+00 4.800000000000000000e+02
0.000000000000000000e+00 4.694746736956946052e+02 2.700000000000000000e+02
0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00
"""
elif dataset == "zero123":
scale_T = 0.5
else:
raise Exception("invalid dataset type")
# convert x y z to x -y -z
if dataset == "zero123":
flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
for i in range(len(RT_list)):
RT_list[i] = np.dot(flip_matrix, RT_list[i])
temp = []
first_frame_RT = copy.deepcopy(RT_list[0])
# first_frame_R_inv = np.linalg.inv(first_frame_RT[:,:3])
first_frame_R_inv = first_frame_RT[:, :3].T
first_frame_T = first_frame_RT[:, -1]
for RT in RT_list:
RT[:, :3] = np.dot(RT[:, :3], first_frame_R_inv)
RT[:, -1] = RT[:, -1] - np.dot(RT[:, :3], first_frame_T)
RT[:, -1] = RT[:, -1] * scale_T
temp.append(RT)
RT_list = temp
if dataset == "realestate":
RT_list = [RT.reshape(-1) for RT in RT_list]
return RT_list
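# --- Usage sketch (not part of the original module) --------------------------
# Re-expresses a short sequence of 3x4 extrinsics relative to its first frame
# using the "zero123" convention (y/z flip plus a 0.5 translation scale). The
# input matrices are illustrative.
def _demo_create_relative():
    rt0 = np.hstack([np.eye(3), np.array([[0.0], [0.0], [2.0]])])
    rt1 = np.hstack([np.eye(3), np.array([[0.5], [0.0], [2.0]])])
    relative = create_relative([rt0, rt1], dataset="zero123")
    # The first frame becomes identity rotation with zero translation
    assert np.allclose(relative[0][:, :3], np.eye(3))
    assert np.allclose(relative[0][:, -1], 0.0)
    return relative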
def sigma_matrix2(sig_x, sig_y, theta):
"""Calculate the rotated sigma matrix (two dimensional matrix).
Args:
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
Returns:
ndarray: Rotated sigma matrix.
"""
d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
u_matrix = np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
def mesh_grid(kernel_size):
"""Generate the mesh grid, centering at zero.
Args:
kernel_size (int):
Returns:
xy (ndarray): with the shape (kernel_size, kernel_size, 2)
xx (ndarray): with the shape (kernel_size, kernel_size)
yy (ndarray): with the shape (kernel_size, kernel_size)
"""
ax = np.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0)
xx, yy = np.meshgrid(ax, ax)
xy = np.hstack(
(
xx.reshape((kernel_size * kernel_size, 1)),
yy.reshape(kernel_size * kernel_size, 1),
)
).reshape(kernel_size, kernel_size, 2)
return xy, xx, yy
def pdf2(sigma_matrix, grid):
"""Calculate PDF of the bivariate Gaussian distribution.
Args:
sigma_matrix (ndarray): with the shape (2, 2)
grid (ndarray): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size.
Returns:
        kernel (ndarray): un-normalized kernel.
"""
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
return kernel
def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
"""Generate a bivariate isotropic or anisotropic Gaussian kernel.
    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
isotropic (bool):
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
if isotropic:
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
else:
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
kernel = pdf2(sigma_matrix, grid)
kernel = kernel / np.sum(kernel)
return kernel
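# --- Usage sketch (not part of the original module) --------------------------
# Generates a 21x21 anisotropic Gaussian blur kernel rotated by 45 degrees and
# checks that it is normalized. The sigma/rotation values are illustrative.
def _demo_bivariate_gaussian():
    kernel = bivariate_Gaussian(
        kernel_size=21, sig_x=3.0, sig_y=1.0, theta=np.pi / 4, isotropic=False
    )
    assert kernel.shape == (21, 21)
    assert np.isclose(kernel.sum(), 1.0)
    return kernel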
def rgba_to_rgb_with_bg(rgba_image, bg_color=(255, 255, 255)):
"""
    Convert a PIL RGBA Image to an RGB Image composited over a solid background.
    Args:
        rgba_image (Image): A PIL Image object in RGBA mode.
        bg_color (tuple): RGB background color; defaults to white.
    Returns:
        Image: A PIL Image object in RGB mode with the given background.
    """
    # If the image is not already RGBA, return it unchanged
if rgba_image.mode != "RGBA":
return rgba_image
# raise ValueError("The image must be in RGBA mode")
# Create a white background image
white_bg_rgb = Image.new("RGB", rgba_image.size, bg_color)
# Paste the RGBA image onto the white background using alpha channel as mask
white_bg_rgb.paste(
rgba_image, mask=rgba_image.split()[3]
) # 3 is the alpha channel index
return white_bg_rgb
def random_order_preserving_selection(items, num):
if num > len(items):
print("WARNING: Item list is shorter than `num` given.")
return items
selected_indices = sorted(random.sample(range(len(items)), num))
selected_items = [items[i] for i in selected_indices]
return selected_items
def pad_pil_image_to_square(image, fill_color=(255, 255, 255)):
"""
Pad an image to make it square with the given fill color.
Args:
image (PIL.Image): The original image.
        fill_color (tuple): The color to use for padding (default is white).
Returns:
PIL.Image: A new image that is padded to be square.
"""
width, height = image.size
# Determine the new size, which will be the maximum of width or height
new_size = max(width, height)
# Create a new image with the new size and fill color
new_image = Image.new("RGB", (new_size, new_size), fill_color)
# Calculate the position to paste the original image onto the new image
# This calculation centers the original image in the new square canvas
left = (new_size - width) // 2
top = (new_size - height) // 2
# Paste the original image into the new image
new_image.paste(image, (left, top))
return new_image
import torch
import numpy as np
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(
device=self.parameters.device
)
def sample(self, noise=None):
if noise is None:
noise = torch.randn(self.mean.shape)
x = self.mean + self.std * noise.to(device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self):
return self.mean
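# --- Usage sketch (not part of the original module) --------------------------
# Builds a posterior from a (B, 2*C, H, W) tensor of concatenated means and
# log-variances, draws a latent sample and evaluates the KL against a standard
# normal. The shapes are illustrative.
def _demo_diagonal_gaussian():
    parameters = torch.randn(2, 8, 4, 4)          # mean (2,4,4,4) + logvar (2,4,4,4)
    posterior = DiagonalGaussianDistribution(parameters)
    z = posterior.sample()                        # (2, 4, 4, 4)
    kl = posterior.kl()                           # (2,), summed over C, H, W
    nll = posterior.nll(z)                        # (2,)
    assert z.shape == (2, 4, 4, 4) and kl.shape == (2,) and nll.shape == (2,)
    return z, kl, nll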
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
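# --- Usage sketch (not part of the original module) --------------------------
# A worked example: KL(N(0, 1) || N(1, 1)) per element is 0.5, and the helper
# broadcasts tensor/scalar arguments. The values are illustrative.
def _demo_normal_kl():
    mean1 = torch.zeros(4)
    logvar1 = torch.zeros(4)
    kl = normal_kl(mean1, logvar1, mean2=1.0, logvar2=0.0)
    assert torch.allclose(kl, torch.full((4,), 0.5))
    return kl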
import torch
from torch import nn
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError("Decay must be between 0 and 1")
self.m_name2s_name = {}
self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
"num_updates",
(
torch.tensor(0, dtype=torch.int)
if use_num_upates
else torch.tensor(-1, dtype=torch.int)
),
)
for name, p in model.named_parameters():
if p.requires_grad:
# remove as '.'-character is not allowed in buffers
s_name = name.replace(".", "")
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(
one_minus_decay * (shadow_params[sname] - m_param[key])
)
else:
assert not key in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
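# --- Usage sketch (not part of the original module) --------------------------
# Tracks an exponential moving average of a tiny model's weights, temporarily
# swaps the EMA weights in for evaluation, then restores the raw weights. The
# model and the stand-in "training step" are illustrative.
def _demo_litema():
    model = nn.Linear(4, 2)
    ema = LitEma(model, decay=0.999)
    # ... after an optimizer step, update the shadow weights:
    with torch.no_grad():
        model.weight.add_(0.01)
    ema(model)
    # Evaluate with EMA weights, then restore the training weights
    ema.store(model.parameters())
    ema.copy_to(model)
    ema.restore(model.parameters())
    return ema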
from core.losses.contperceptual import LPIPSWithDiscriminator
import torch
import torch.nn as nn
from einops import rearrange
from taming.modules.losses.vqperceptual import *
class LPIPSWithDiscriminator(nn.Module):
def __init__(
self,
disc_start,
logvar_init=0.0,
kl_weight=1.0,
pixelloss_weight=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
perceptual_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_loss="hinge",
max_bs=None,
):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
self.kl_weight = kl_weight
self.pixel_weight = pixelloss_weight
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
# output log variance
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
).apply(weights_init)
self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
self.max_bs = max_bs
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
nll_grads = torch.autograd.grad(
nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(
self,
inputs,
reconstructions,
posteriors,
optimizer_idx,
global_step,
last_layer=None,
cond=None,
split="train",
weights=None,
):
if inputs.dim() == 5:
inputs = rearrange(inputs, "b c t h w -> (b t) c h w")
if reconstructions.dim() == 5:
reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w")
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
if self.perceptual_weight > 0:
if self.max_bs is not None and self.max_bs < inputs.shape[0]:
input_list = torch.split(inputs, self.max_bs, dim=0)
reconstruction_list = torch.split(reconstructions, self.max_bs, dim=0)
p_losses = [
self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
for inputs, reconstructions in zip(input_list, reconstruction_list)
]
p_loss = torch.cat(p_losses, dim=0)
else:
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights * nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
kl_loss = posteriors.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
# now the GAN part
if optimizer_idx == 0:
# generator update
if cond is None:
assert not self.disc_conditional
logits_fake = self.discriminator(reconstructions.contiguous())
else:
assert self.disc_conditional
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous(), cond), dim=1)
)
g_loss = -torch.mean(logits_fake)
if self.disc_factor > 0.0:
try:
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
else:
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
loss = (
weighted_nll_loss
+ self.kl_weight * kl_loss
+ d_weight * disc_factor * g_loss
)
log = {
"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/logvar".format(split): self.logvar.detach(),
"{}/kl_loss".format(split): kl_loss.detach().mean(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
else:
logits_real = self.discriminator(
torch.cat((inputs.contiguous().detach(), cond), dim=1)
)
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
)
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {
"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean(),
}
return d_loss, log
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3])
loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3])
loss_real = (weights * loss_real).sum() / weights.sum()
loss_fake = (weights * loss_fake).sum() / weights.sum()
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
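# --- Usage sketch (not part of the original module) --------------------------
# Weighted hinge loss over a batch of patch-discriminator logits; the shapes
# and the per-example weights are illustrative.
def _demo_hinge_d_loss_with_exemplar_weights():
    logits_real = torch.randn(4, 1, 30, 30)
    logits_fake = torch.randn(4, 1, 30, 30)
    weights = torch.tensor([1.0, 1.0, 0.5, 2.0])
    d_loss = hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights)
    assert d_loss.dim() == 0
    return d_loss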
def adopt_weight(weight, global_step, threshold=0, value=0.0):
if global_step < threshold:
weight = value
return weight
def measure_perplexity(predicted_indices, n_embed):
encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use
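# --- Usage sketch (not part of the original module) --------------------------
# Perplexity / codebook-usage statistics for a toy batch of predicted VQ
# indices. With all n_embed codes used equally, perplexity approaches n_embed.
def _demo_measure_perplexity():
    predicted_indices = torch.arange(16) % 8          # uses all 8 codes uniformly
    perplexity, cluster_use = measure_perplexity(predicted_indices, n_embed=8)
    assert int(cluster_use) == 8
    return perplexity, cluster_use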
def l1(x, y):
return torch.abs(x - y)
def l2(x, y):
return torch.pow((x - y), 2)
class VQLPIPSWithDiscriminator(nn.Module):
def __init__(
self,
disc_start,
codebook_weight=1.0,
pixelloss_weight=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
perceptual_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_ndf=64,
disc_loss="hinge",
n_classes=None,
perceptual_loss="lpips",
pixel_loss="l1",
):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
assert perceptual_loss in ["lpips", "clips", "dists"]
assert pixel_loss in ["l1", "l2"]
self.codebook_weight = codebook_weight
self.pixel_weight = pixelloss_weight
if perceptual_loss == "lpips":
print(f"{self.__class__.__name__}: Running with LPIPS.")
self.perceptual_loss = LPIPS().eval()
else:
raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
self.perceptual_weight = perceptual_weight
if pixel_loss == "l1":
self.pixel_loss = l1
else:
self.pixel_loss = l2
self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm,
ndf=disc_ndf,
).apply(weights_init)
self.discriminator_iter_start = disc_start
if disc_loss == "hinge":
self.disc_loss = hinge_d_loss
elif disc_loss == "vanilla":
self.disc_loss = vanilla_d_loss
else:
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
self.n_classes = n_classes
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
nll_grads = torch.autograd.grad(
nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(
self,
codebook_loss,
inputs,
reconstructions,
optimizer_idx,
global_step,
last_layer=None,
cond=None,
split="train",
predicted_indices=None,
):
        if codebook_loss is None:
codebook_loss = torch.tensor([0.0]).to(inputs.device)
# rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
else:
p_loss = torch.tensor([0.0])
nll_loss = rec_loss
# nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
nll_loss = torch.mean(nll_loss)
# now the GAN part
if optimizer_idx == 0:
# generator update
if cond is None:
assert not self.disc_conditional
logits_fake = self.discriminator(reconstructions.contiguous())
else:
assert self.disc_conditional
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous(), cond), dim=1)
)
g_loss = -torch.mean(logits_fake)
try:
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
loss = (
nll_loss
+ d_weight * disc_factor * g_loss
+ self.codebook_weight * codebook_loss.mean()
)
log = {
"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/p_loss".format(split): p_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
if predicted_indices is not None:
assert self.n_classes is not None
with torch.no_grad():
perplexity, cluster_usage = measure_perplexity(
predicted_indices, self.n_classes
)
log[f"{split}/perplexity"] = perplexity
log[f"{split}/cluster_usage"] = cluster_usage
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
if cond is None:
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
else:
logits_real = self.discriminator(
torch.cat((inputs.contiguous().detach(), cond), dim=1)
)
logits_fake = self.discriminator(
torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
)
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {
"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean(),
}
return d_loss, log
import os
import json
from contextlib import contextmanager
import torch
import numpy as np
from einops import rearrange
import torch.nn.functional as F
import torch.distributed as dist
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from taming.modules.vqvae.quantize import VectorQuantizer as VectorQuantizer
from core.modules.networks.ae_modules import Encoder, Decoder
from core.distributions import DiagonalGaussianDistribution
from utils.utils import instantiate_from_config
from utils.save_video import tensor2videogrids
from core.common import shape_to_str, gather_data
class AutoencoderKL(pl.LightningModule):
def __init__(
self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
test=False,
logdir=None,
input_dim=4,
test_args=None,
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
self.input_dim = input_dim
self.test = test
self.test_args = test_args
self.logdir = logdir
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
if self.test:
self.init_test()
def init_test(
self,
):
self.test = True
save_dir = os.path.join(self.logdir, "test")
if "ckpt" in self.test_args:
ckpt_name = (
os.path.basename(self.test_args.ckpt).split(".ckpt")[0]
+ f"_epoch{self._cur_epoch}"
)
self.root = os.path.join(save_dir, ckpt_name)
else:
self.root = save_dir
if "test_subdir" in self.test_args:
self.root = os.path.join(save_dir, self.test_args.test_subdir)
self.root_zs = os.path.join(self.root, "zs")
self.root_dec = os.path.join(self.root, "reconstructions")
self.root_inputs = os.path.join(self.root, "inputs")
os.makedirs(self.root, exist_ok=True)
if self.test_args.save_z:
os.makedirs(self.root_zs, exist_ok=True)
if self.test_args.save_reconstruction:
os.makedirs(self.root_dec, exist_ok=True)
if self.test_args.save_input:
os.makedirs(self.root_inputs, exist_ok=True)
assert self.test_args is not None
self.test_maximum = getattr(
self.test_args, "test_maximum", None
) # 1500 # 12000/8
self.count = 0
self.eval_metrics = {}
self.decodes = []
self.save_decode_samples = 2048
if getattr(self.test_args, "cal_metrics", False):
self.EvalLpips = EvalLpips()
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")
try:
self._cur_epoch = sd["epoch"]
sd = sd["state_dict"]
        except KeyError:
self._cur_epoch = "null"
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
# self.load_state_dict(sd, strict=True)
print(f"Restored from {path}")
def encode(self, x, **kwargs):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z, **kwargs):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
# if len(x.shape) == 3:
# x = x[..., None]
# if x.dim() == 4:
# x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
if x.dim() == 5 and self.input_dim == 4:
b, c, t, h, w = x.shape
self.b = b
self.t = t
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log(
"aeloss",
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log(
"discloss",
discloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
)
return discloss
def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
0,
self.global_step,
last_layer=self.get_last_layer(),
split="val",
)
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
1,
self.global_step,
last_layer=self.get_last_layer(),
split="val",
)
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def test_step(self, batch, batch_idx):
# save z, dec
inputs = self.get_input(batch, self.image_key)
# forward
sample_posterior = True
posterior = self.encode(inputs)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
# logs
if self.test_args.save_z:
torch.save(
z,
os.path.join(
self.root_zs,
f"zs_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.pt",
),
)
if self.test_args.save_reconstruction:
tensor2videogrids(
dec,
self.root_dec,
f"reconstructions_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.mp4",
fps=10,
)
if self.test_args.save_input:
tensor2videogrids(
inputs,
self.root_inputs,
f"inputs_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(z)}.mp4",
fps=10,
)
if "save_z" in self.test_args and self.test_args.save_z:
dec_np = (dec.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) + 1) / 2 * 255
dec_np = dec_np.astype(np.uint8)
self.root_dec_np = os.path.join(self.root, "reconstructions_np")
os.makedirs(self.root_dec_np, exist_ok=True)
np.savez(
os.path.join(
self.root_dec_np,
f"reconstructions_batch{batch_idx}_rank{self.global_rank}_shape{shape_to_str(dec_np)}.npz",
),
dec_np,
)
self.count += z.shape[0]
# misc
self.log("batch_idx", batch_idx, prog_bar=True)
self.log_dict(self.eval_metrics, prog_bar=True, logger=True)
torch.cuda.empty_cache()
if self.test_maximum is not None:
if self.count > self.test_maximum:
import sys
sys.exit()
else:
prog = self.count / self.test_maximum * 100
print(f"Test progress: {prog:.2f}% [{self.count}/{self.test_maximum}]")
@rank_zero_only
def on_test_end(self):
if self.test_args.cal_metrics:
psnrs, ssims, ms_ssims, lpipses = [], [], [], []
n_batches = 0
n_samples = 0
overall = {}
for k, v in self.eval_metrics.items():
psnrs.append(v["psnr"])
ssims.append(v["ssim"])
lpipses.append(v["lpips"])
n_batches += 1
n_samples += v["n_samples"]
mean_psnr = sum(psnrs) / len(psnrs)
mean_ssim = sum(ssims) / len(ssims)
# overall['ms_ssim'] = min(ms_ssims)
mean_lpips = sum(lpipses) / len(lpipses)
overall = {
"psnr": mean_psnr,
"ssim": mean_ssim,
"lpips": mean_lpips,
"n_batches": n_batches,
"n_samples": n_samples,
}
overall_t = torch.tensor([mean_psnr, mean_ssim, mean_lpips])
# dump
for k, v in overall.items():
if isinstance(v, torch.Tensor):
overall[k] = float(v)
with open(
os.path.join(self.root, f"reconstruction_metrics.json"), "w"
) as f:
json.dump(overall, f)
f.close()
def configure_optimizers(self):
lr = self.learning_rate
opt_ae = torch.optim.Adam(
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr,
betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
)
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
log["inputs"] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x
from .sampler import DPMSolverSampler
"""SAMPLING ONLY."""
import torch
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
MODEL_TYPES = {"eps": "noise", "v": "v"}
class DPMSolverSampler(object):
def __init__(self, model, **kwargs):
super().__init__()
self.model = model
def to_torch(x):
return x.clone().detach().to(torch.float32).to(model.device)
self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod))
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
x_T=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
        # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
try:
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                except AttributeError:
cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
if cbs != batch_size:
print(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size:
print(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
# sampling
T, C, H, W = shape
size = (batch_size, T, C, H, W)
print(f"Data shape for DPM-Solver sampling is {size}, sampling steps {S}")
device = self.model.betas.device
if x_T is None:
img = torch.randn(size, device=device)
else:
img = x_T
ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
model_fn = model_wrapper(
lambda x, t, c: self.model.apply_model(x, t, c),
ns,
model_type=MODEL_TYPES[self.model.parameterization],
guidance_type="classifier-free",
condition=conditioning,
unconditional_condition=unconditional_conditioning,
guidance_scale=unconditional_guidance_scale,
)
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
x = dpm_solver.sample(
img,
steps=S,
skip_type="time_uniform",
method="multistep",
order=2,
lower_order_final=True,
)
return x.to(device), None
"""SAMPLING ONLY."""
import numpy as np
from tqdm import tqdm
import torch
from core.models.utils_diffusion import (
make_ddim_sampling_parameters,
make_ddim_time_steps,
)
from core.common import noise_like
class PLMSSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_time_steps = model.num_time_steps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
):
if ddim_eta != 0:
raise ValueError("ddim_eta must be 0 for PLMS")
self.ddim_time_steps = make_ddim_time_steps(
ddim_discr_method=ddim_discretize,
num_ddim_time_steps=ddim_num_steps,
num_ddpm_time_steps=self.ddpm_num_time_steps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
alphas_cumprod.shape[0] == self.ddpm_num_time_steps
), "alphas have to be defined for each timestep"
def to_torch(x):
return x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer("betas", to_torch(self.model.betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer(
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_one_minus_alphas_cumprod",
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod",
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_time_steps=self.ddim_time_steps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer("ddim_sigmas", ddim_sigmas)
self.register_buffer("ddim_alphas", ddim_alphas)
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
        # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size:
print(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f"Data shape for PLMS sampling is {size}")
samples, intermediates = self.plms_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad()
def plms_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
time_steps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if time_steps is None:
time_steps = (
self.ddpm_num_time_steps
if ddim_use_original_steps
else self.ddim_time_steps
)
elif time_steps is not None and not ddim_use_original_steps:
subset_end = (
int(
min(time_steps / self.ddim_time_steps.shape[0], 1)
* self.ddim_time_steps.shape[0]
)
- 1
)
time_steps = self.ddim_time_steps[:subset_end]
intermediates = {"x_inter": [img], "pred_x0": [img]}
time_range = (
list(reversed(range(0, time_steps)))
if ddim_use_original_steps
else np.flip(time_steps)
)
total_steps = time_steps if ddim_use_original_steps else time_steps.shape[0]
print(f"Running PLMS Sampling with {total_steps} time_steps")
iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full(
(b,),
time_range[min(i + 1, len(time_range) - 1)],
device=device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts)
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_plms(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates["x_inter"].append(img)
intermediates["pred_x0"].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_plms(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if (
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (
55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t
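# --- Illustrative sketch (not part of the original module) -------------------
# The multistep branches above combine the current noise prediction e_t with a
# history of previous predictions using Adams-Bashforth coefficients. This toy
# example reproduces the 4th-order combination on scalar tensors.
def _demo_plms_fourth_order_combination():
    e_t = torch.tensor(1.0, dtype=torch.float64)
    # history of previous noise predictions, oldest first
    old_eps = [torch.tensor(v, dtype=torch.float64) for v in (0.7, 0.8, 0.9)]
    e_t_prime = (
        55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
    ) / 24
    # (55*1.0 - 59*0.9 + 37*0.8 - 9*0.7) / 24 = 25.2 / 24 = 1.05
    assert torch.isclose(e_t_prime, torch.tensor(1.05, dtype=torch.float64), atol=1e-6)
    return e_t_prime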
"""SAMPLING ONLY."""
import torch
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
class UniPCSampler(object):
def __init__(self, model, **kwargs):
super().__init__()
self.model = model
def to_torch(x):
return x.clone().detach().to(torch.float32).to(model.device)
self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod))
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
x_T=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
# sampling
T, C, H, W = shape
size = (batch_size, T, C, H, W)
device = self.model.betas.device
if x_T is None:
img = torch.randn(size, device=device)
else:
img = x_T
ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
model_fn = model_wrapper(
lambda x, t, c: self.model.apply_model(x, t, c),
ns,
model_type="v",
guidance_type="classifier-free",
condition=conditioning,
unconditional_condition=unconditional_conditioning,
guidance_scale=unconditional_guidance_scale,
)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False)
x = uni_pc.sample(
img,
steps=S,
skip_type="time_uniform",
method="multistep",
order=2,
lower_order_final=True,
)
return x.to(device), None