Commit 63bde97a authored by chenpangpang's avatar chenpangpang

feat: initial commit

parent 9cf8c6f1
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Modified by Jiale Xu
# The modifications are subject to the same license as the original.
"""
The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths.
Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class MipRayMarcher2(nn.Module):
def __init__(self, activation_factory):
super().__init__()
self.activation_factory = activation_factory
def run_forward(self, colors, densities, depths, rendering_options, normals=None):
dtype = colors.dtype
deltas = depths[:, :, 1:] - depths[:, :, :-1]
colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
# using factory mode for better usability
densities_mid = self.activation_factory(rendering_options)(densities_mid).to(dtype)
density_delta = densities_mid * deltas
alpha = 1 - torch.exp(-density_delta).to(dtype)
alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2)
weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]
weights = weights.to(dtype)
composite_rgb = torch.sum(weights * colors_mid, -2)
weight_total = weights.sum(2)
# composite_depth = torch.sum(weights * depths_mid, -2) / weight_total
composite_depth = torch.sum(weights * depths_mid, -2)
# clip the composite to min/max range of depths
composite_depth = torch.nan_to_num(composite_depth, float('inf')).to(dtype)
composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))
if rendering_options.get('white_back', False):
composite_rgb = composite_rgb + 1 - weight_total
# rendered value scale is 0-1, comment out original mipnerf scaling
# composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1)
return composite_rgb, composite_depth, weights
def forward(self, colors, densities, depths, rendering_options, normals=None):
if normals is not None:
# NOTE: run_forward ignores `normals` and only returns three values, so this
# branch would fail if normals were ever passed; it is kept for interface parity.
composite_rgb, composite_depth, composite_normals, weights = self.run_forward(colors, densities, depths, rendering_options, normals)
return composite_rgb, composite_depth, composite_normals, weights
composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options)
return composite_rgb, composite_depth, weights
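# --- Hedged usage sketch (not part of the original file) ---
# Shows how MipRayMarcher2 composites per-sample colors and densities along rays.
# The softplus factory below mirrors the one ImportanceRenderer builds; the tensor
# shapes and rendering options are illustrative assumptions.
def _demo_mip_ray_marcher():
    def softplus_factory(options):
        assert options['clamp_mode'] == 'softplus'
        return lambda x: F.softplus(x - 1)

    marcher = MipRayMarcher2(softplus_factory)
    B, R, S = 1, 4, 16                                   # batch, rays, samples per ray
    colors = torch.rand(B, R, S, 3)
    densities = torch.randn(B, R, S, 1)
    depths = torch.linspace(2.0, 3.0, S).view(1, 1, S, 1).repeat(B, R, 1, 1)
    rgb, depth, weights = marcher(colors, densities, depths, {'clamp_mode': 'softplus'})
    return rgb.shape, depth.shape, weights.shape         # (1, 4, 3), (1, 4, 1), (1, 4, 15, 1)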
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Modified by Jiale Xu
# The modifications are subject to the same license as the original.
"""
The ray sampler is a module that takes in camera matrices and a render resolution and produces batches of rays.
Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
"""
import torch
class RaySampler(torch.nn.Module):
def __init__(self):
super().__init__()
self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
def forward(self, cam2world_matrix, intrinsics, render_size):
"""
Create batches of rays and return origins and directions.
cam2world_matrix: (N, 4, 4)
intrinsics: (N, 3, 3)
render_size: int
ray_origins: (N, M, 3)
ray_dirs: (N, M, 3)
"""
dtype = cam2world_matrix.dtype
device = cam2world_matrix.device
N, M = cam2world_matrix.shape[0], render_size**2
cam_locs_world = cam2world_matrix[:, :3, 3]
fx = intrinsics[:, 0, 0]
fy = intrinsics[:, 1, 1]
cx = intrinsics[:, 0, 2]
cy = intrinsics[:, 1, 2]
sk = intrinsics[:, 0, 1]
uv = torch.stack(torch.meshgrid(
torch.arange(render_size, dtype=dtype, device=device),
torch.arange(render_size, dtype=dtype, device=device),
indexing='ij',
))
uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size)
y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size)
z_cam = torch.ones((N, M), dtype=dtype, device=device)
x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam
y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1).to(dtype)
_opencv2blender = torch.tensor([
[1, 0, 0, 0],
[0, -1, 0, 0],
[0, 0, -1, 0],
[0, 0, 0, 1],
], dtype=dtype, device=device).unsqueeze(0).repeat(N, 1, 1)
cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender)
world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
ray_dirs = world_rel_points - cam_locs_world[:, None, :]
ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2).to(dtype)
ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)
return ray_origins, ray_dirs
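# --- Hedged usage sketch (not part of the original file) ---
# Samples rays for a single camera placed 2 units from the origin, with intrinsics
# normalized by image size (as FOV_to_intrinsics produces). The 30-degree FOV and
# render size are illustrative assumptions.
def _demo_ray_sampler():
    import math
    sampler = RaySampler()
    cam2world = torch.eye(4).unsqueeze(0)                 # (1, 4, 4), OpenCV convention
    cam2world[:, 2, 3] = 2.0                              # move the camera along +z
    focal = 0.5 / math.tan(math.radians(30.0) * 0.5)      # normalized focal length
    intrinsics = torch.tensor([[[focal, 0.0, 0.5],
                                [0.0, focal, 0.5],
                                [0.0, 0.0, 1.0]]])        # (1, 3, 3)
    ray_origins, ray_dirs = sampler(cam2world, intrinsics, render_size=64)
    return ray_origins.shape, ray_dirs.shape              # (1, 4096, 3), (1, 4096, 3)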
class OrthoRaySampler(torch.nn.Module):
def __init__(self):
super().__init__()
self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
def forward(self, cam2world_matrix, ortho_scale, render_size):
"""
Create batches of rays and return origins and directions.
cam2world_matrix: (N, 4, 4)
ortho_scale: float
render_size: int
ray_origins: (N, M, 3)
ray_dirs: (N, M, 3)
"""
N, M = cam2world_matrix.shape[0], render_size**2
uv = torch.stack(torch.meshgrid(
torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
indexing='ij',
))
uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size)
y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size)
z_cam = torch.zeros((N, M), device=cam2world_matrix.device)
x_lift = (x_cam - 0.5) * ortho_scale
y_lift = (y_cam - 0.5) * ortho_scale
cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1)
_opencv2blender = torch.tensor([
[1, 0, 0, 0],
[0, -1, 0, 0],
[0, 0, -1, 0],
[0, 0, 0, 1],
], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1)
cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender)
ray_origins = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
ray_dirs_cam = torch.stack([
torch.zeros((N, M), device=cam2world_matrix.device),
torch.zeros((N, M), device=cam2world_matrix.device),
torch.ones((N, M), device=cam2world_matrix.device),
], dim=-1)
ray_dirs = torch.bmm(cam2world_matrix[:, :3, :3], ray_dirs_cam.permute(0, 2, 1)).permute(0, 2, 1)
return ray_origins, ray_dirs
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Modified by Jiale Xu
# The modifications are subject to the same license as the original.
"""
The renderer is a module that takes in rays, decides where to sample along each
ray, and computes pixel colors using the volume rendering equation.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .ray_marcher import MipRayMarcher2
from . import math_utils
def generate_planes():
"""
Defines planes by the three vectors that form the "axes" of the
plane. Should work with arbitrary number of planes and planes of
arbitrary orientation.
Bugfix reference: https://github.com/NVlabs/eg3d/issues/67
"""
return torch.tensor([[[1, 0, 0],
[0, 1, 0],
[0, 0, 1]],
[[1, 0, 0],
[0, 0, 1],
[0, 1, 0]],
[[0, 0, 1],
[0, 1, 0],
[1, 0, 0]]], dtype=torch.float32)
def project_onto_planes(planes, coordinates):
"""
Projects a batch of 3D points onto a batch of 2D planes,
returning 2D plane coordinates.
Takes plane axes of shape (n_planes, 3, 3)
Takes coordinates of shape (N, M, 3)
Returns projections of shape (N*n_planes, M, 2)
"""
N, M, C = coordinates.shape
n_planes, _, _ = planes.shape
coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
projections = torch.bmm(coordinates, inv_planes)
return projections[..., :2]
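# --- Hedged example (not part of the original file) ---
# Projects a single 3D point onto the three canonical planes from generate_planes();
# each row of the result keeps the two in-plane coordinates of the point.
def _demo_project_onto_planes():
    axes = generate_planes()                              # (3, 3, 3)
    pts = torch.tensor([[[0.1, 0.2, 0.3]]])               # (N=1, M=1, 3)
    proj = project_onto_planes(axes, pts)                 # (N*3, M, 2)
    return proj.squeeze(1)                                # [[0.1, 0.2], [0.1, 0.3], [0.3, 0.2]]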
def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
assert padding_mode == 'zeros'
N, n_planes, C, H, W = plane_features.shape
_, M, _ = coordinates.shape
plane_features = plane_features.view(N*n_planes, C, H, W)
dtype = plane_features.dtype
coordinates = (2/box_warp) * coordinates # map [-box_warp/2, box_warp/2] to [-1, 1] for grid_sample
projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
output_features = torch.nn.functional.grid_sample(
plane_features,
projected_coordinates.to(dtype),
mode=mode,
padding_mode=padding_mode,
align_corners=False,
).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
return output_features
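# --- Hedged usage sketch (not part of the original file) ---
# Queries random tri-plane features at a few 3D points inside a box of side length 2
# (box_warp=2 is an assumption); the channel count and resolution are illustrative.
def _demo_sample_from_planes():
    plane_axes = generate_planes()
    planes = torch.randn(1, 3, 32, 64, 64)                # (N, n_planes, C, H, W)
    coords = torch.rand(1, 128, 3) * 2 - 1                # (N, M, 3) in [-1, 1]^3
    feats = sample_from_planes(plane_axes, planes, coords, box_warp=2.0)
    return feats.shape                                    # (1, 3, 128, 32)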
def sample_from_3dgrid(grid, coordinates):
"""
Expects coordinates in shape (batch_size, num_points_per_batch, 3)
Expects grid in shape (1, channels, H, W, D)
(Also works if grid has batch size)
Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
"""
batch_size, n_coords, n_dims = coordinates.shape
sampled_features = torch.nn.functional.grid_sample(
grid.expand(batch_size, -1, -1, -1, -1),
coordinates.reshape(batch_size, 1, 1, -1, n_dims),
mode='bilinear',
padding_mode='zeros',
align_corners=False,
)
N, C, H, W, D = sampled_features.shape
sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C)
return sampled_features
class ImportanceRenderer(torch.nn.Module):
"""
Modified original version to filter out-of-box samples as TensoRF does.
Reference:
TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277
"""
def __init__(self):
super().__init__()
self.activation_factory = self._build_activation_factory()
self.ray_marcher = MipRayMarcher2(self.activation_factory)
self.plane_axes = generate_planes()
def _build_activation_factory(self):
def activation_factory(options: dict):
if options['clamp_mode'] == 'softplus':
return lambda x: F.softplus(x - 1) # activation bias of -1 makes things initialize better
else:
assert False, "Renderer only supports `clamp_mode`=`softplus`!"
return activation_factory
def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor,
planes: torch.Tensor, decoder: nn.Module, rendering_options: dict):
"""
Additional filtering is applied to filter out-of-box samples.
Modifications made by Zexin He.
"""
# context related variables
batch_size, num_rays, samples_per_ray, _ = depths.shape
device = depths.device
# define sample points with depths
sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
# filter out-of-box samples
mask_inbox = \
(rendering_options['sampler_bbox_min'] <= sample_coordinates) & \
(sample_coordinates <= rendering_options['sampler_bbox_max'])
mask_inbox = mask_inbox.all(-1)
# forward model according to all samples
_out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)
# set out-of-box samples to zeros(rgb) & -inf(sigma)
SAFE_GUARD = 3
DATA_TYPE = _out['sigma'].dtype
colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE)
densities_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD
colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox]
# reshape back
colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1])
densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1])
return colors_pass, densities_pass
def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options):
# self.plane_axes = self.plane_axes.to(ray_origins.device)
if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
is_ray_valid = ray_end > ray_start
if torch.any(is_ray_valid).item():
ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
else:
# Create stratified depth samples
depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
# Coarse Pass
colors_coarse, densities_coarse = self._forward_pass(
depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins,
planes=planes, decoder=decoder, rendering_options=rendering_options)
# Fine Pass
N_importance = rendering_options['depth_resolution_importance']
if N_importance > 0:
_, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
depths_fine = self.sample_importance(depths_coarse, weights, N_importance)
colors_fine, densities_fine = self._forward_pass(
depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins,
planes=planes, decoder=decoder, rendering_options=rendering_options)
all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
depths_fine, colors_fine, densities_fine)
rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options)
else:
rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
return rgb_final, depth_final, weights.sum(2)
def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
plane_axes = self.plane_axes.to(planes.device)
sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])
out = decoder(sampled_features, sample_directions)
if options.get('density_noise', 0) > 0:
out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise']
return out
def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options):
out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options)
out['sigma'] = self.activation_factory(options)(out['sigma'])
return out
def sort_samples(self, all_depths, all_colors, all_densities):
_, indices = torch.sort(all_depths, dim=-2)
all_depths = torch.gather(all_depths, -2, indices)
all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
return all_depths, all_colors, all_densities
def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2, normals1=None, normals2=None):
all_depths = torch.cat([depths1, depths2], dim = -2)
all_colors = torch.cat([colors1, colors2], dim = -2)
all_densities = torch.cat([densities1, densities2], dim = -2)
if normals1 is not None and normals2 is not None:
all_normals = torch.cat([normals1, normals2], dim = -2)
else:
all_normals = None
_, indices = torch.sort(all_depths, dim=-2)
all_depths = torch.gather(all_depths, -2, indices)
all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
if all_normals is not None:
all_normals = torch.gather(all_normals, -2, indices.expand(-1, -1, -1, all_normals.shape[-1]))
return all_depths, all_colors, all_normals, all_densities
return all_depths, all_colors, all_densities
def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
"""
Return depths of approximately uniformly spaced samples along rays.
"""
N, M, _ = ray_origins.shape
if disparity_space_sampling:
depths_coarse = torch.linspace(0,
1,
depth_resolution,
device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
depth_delta = 1/(depth_resolution - 1)
depths_coarse += torch.rand_like(depths_coarse) * depth_delta
depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
else:
if type(ray_start) == torch.Tensor:
depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
else:
depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
depth_delta = (ray_end - ray_start)/(depth_resolution - 1)
depths_coarse += torch.rand_like(depths_coarse) * depth_delta
return depths_coarse
def sample_importance(self, z_vals, weights, N_importance):
"""
Return depths of importance sampled points along rays. See NeRF importance sampling for more.
"""
with torch.no_grad():
batch_size, num_rays, samples_per_ray, _ = z_vals.shape
z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher
# smooth weights
weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1), 2, 1, padding=1)
weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze()
weights = weights + 0.01
z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:])
importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
return importance_z_vals
def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
"""
Sample @N_importance samples from @bins with distribution defined by @weights.
Inputs:
bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
weights: (N_rays, N_samples_)
N_importance: the number of samples to draw from the distribution
det: deterministic or not
eps: a small number to prevent division by zero
Outputs:
samples: the sampled samples
"""
N_rays, N_samples_ = weights.shape
weights = weights + eps # prevent division by zero (don't do inplace op!)
pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_)
cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function
cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1)
# padded to 0~1 inclusive
if det:
u = torch.linspace(0, 1, N_importance, device=bins.device)
u = u.expand(N_rays, N_importance)
else:
u = torch.rand(N_rays, N_importance, device=bins.device)
u = u.contiguous()
inds = torch.searchsorted(cdf, u, right=True)
below = torch.clamp_min(inds-1, 0)
above = torch.clamp_max(inds, N_samples_)
inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance)
cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)
denom = cdf_g[...,1]-cdf_g[...,0]
denom[denom<eps] = 1 # denom equals 0 means a bin has weight 0, in which case it will not be sampled
# anyway, therefore any value for it is fine (set to 1 here)
samples = bins_g[...,0] + (u-cdf_g[...,0])/denom * (bins_g[...,1]-bins_g[...,0])
return samples
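# --- Hedged example (not part of the original file) ---
# Inverse-CDF sampling with sample_pdf: extra depths are drawn where the coarse
# weights put their mass. The bin edges and weights below are illustrative.
def _demo_sample_pdf():
    renderer = ImportanceRenderer()
    bins = torch.linspace(2.0, 3.0, 9).unsqueeze(0)                    # (1, 9) bin edges
    weights = torch.tensor([[0., 0., 1., 4., 1., 0., 0., 0.]])         # (1, 8)
    samples = renderer.sample_pdf(bins, weights, N_importance=16, det=True)
    return samples                                                     # (1, 16), clustered near 2.4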
import torch
import torch.nn.functional as F
import numpy as np
def pad_camera_extrinsics_4x4(extrinsics):
if extrinsics.shape[-2] == 4:
return extrinsics
padding = torch.tensor([[0, 0, 0, 1]]).to(extrinsics)
if extrinsics.ndim == 3:
padding = padding.unsqueeze(0).repeat(extrinsics.shape[0], 1, 1)
extrinsics = torch.cat([extrinsics, padding], dim=-2)
return extrinsics
def center_looking_at_camera_pose(camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None):
"""
Create OpenGL camera extrinsics from camera locations and look-at position.
camera_position: (M, 3) or (3,)
look_at: (3)
up_world: (3)
return: (M, 3, 4) or (3, 4)
"""
# by default, looking at the origin and world up is z-axis
if look_at is None:
look_at = torch.tensor([0, 0, 0], dtype=torch.float32)
if up_world is None:
up_world = torch.tensor([0, 0, 1], dtype=torch.float32)
if camera_position.ndim == 2:
look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)
# OpenGL camera: z-backward, x-right, y-up
z_axis = camera_position - look_at
z_axis = F.normalize(z_axis, dim=-1).float()
x_axis = torch.linalg.cross(up_world, z_axis, dim=-1)
x_axis = F.normalize(x_axis, dim=-1).float()
y_axis = torch.linalg.cross(z_axis, x_axis, dim=-1)
y_axis = F.normalize(y_axis, dim=-1).float()
extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
extrinsics = pad_camera_extrinsics_4x4(extrinsics)
return extrinsics
def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
azimuths = np.deg2rad(azimuths)
elevations = np.deg2rad(elevations)
xs = radius * np.cos(elevations) * np.cos(azimuths)
ys = radius * np.cos(elevations) * np.sin(azimuths)
zs = radius * np.sin(elevations)
cam_locations = np.stack([xs, ys, zs], axis=-1)
cam_locations = torch.from_numpy(cam_locations).float()
c2ws = center_looking_at_camera_pose(cam_locations)
return c2ws
def get_circular_camera_poses(M=120, radius=2.5, elevation=30.0):
# M: number of circular views
# radius: camera dist to center
# elevation: elevation degrees of the camera
# return: (M, 4, 4)
assert M > 0 and radius > 0
elevation = np.deg2rad(elevation)
camera_positions = []
for i in range(M):
azimuth = 2 * np.pi * i / M
x = radius * np.cos(elevation) * np.cos(azimuth)
y = radius * np.cos(elevation) * np.sin(azimuth)
z = radius * np.sin(elevation)
camera_positions.append([x, y, z])
camera_positions = np.array(camera_positions)
camera_positions = torch.from_numpy(camera_positions).float()
extrinsics = center_looking_at_camera_pose(camera_positions)
return extrinsics
def FOV_to_intrinsics(fov, device='cpu'):
"""
Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
Note the intrinsics are returned as normalized by image size, rather than in pixel units.
Assumes principal point is at image center.
"""
focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5)
intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
return intrinsics
def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0):
"""
Get the input camera parameters.
"""
azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float)
elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float)
c2ws = spherical_camera_pose(azimuths, elevations, radius)
c2ws = c2ws.float().flatten(-2)
Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(6, 1, 1).float().flatten(-2)
extrinsics = c2ws[:, :12]
intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1)
cameras = torch.cat([extrinsics, intrinsics], dim=-1)
return cameras.unsqueeze(0).repeat(batch_size, 1, 1)
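# --- Hedged usage sketch (not part of the original file) ---
# Builds the six Zero123++ input cameras and a circular render trajectory; the
# batch size, radius and elevation values are illustrative.
def _demo_cameras():
    input_cameras = get_zero123plus_input_cameras(batch_size=2, radius=4.0, fov=30.0)
    render_c2ws = get_circular_camera_poses(M=120, radius=2.5, elevation=20.0)
    return input_cameras.shape, render_c2ws.shape          # (2, 6, 16), (120, 4, 4)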
import os
import imageio
import rembg
import torch
import numpy as np
import PIL.Image
from PIL import Image
from typing import Any
def remove_background(image: PIL.Image.Image,
rembg_session: Any = None,
force: bool = False,
**rembg_kwargs,
) -> PIL.Image.Image:
do_remove = True
if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
do_remove = False
do_remove = do_remove or force
if do_remove:
image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
return image
def resize_foreground(
image: PIL.Image.Image,
ratio: float,
) -> PIL.Image.Image:
image = np.array(image)
assert image.shape[-1] == 4
alpha = np.where(image[..., 3] > 0)
y1, y2, x1, x2 = (
alpha[0].min(),
alpha[0].max(),
alpha[1].min(),
alpha[1].max(),
)
# crop the foreground
fg = image[y1:y2, x1:x2]
# pad to square
size = max(fg.shape[0], fg.shape[1])
ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
new_image = np.pad(
fg,
((ph0, ph1), (pw0, pw1), (0, 0)),
mode="constant",
constant_values=((0, 0), (0, 0), (0, 0)),
)
# compute padding according to the ratio
new_size = int(new_image.shape[0] / ratio)
# pad to size, double side
ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
ph1, pw1 = new_size - size - ph0, new_size - size - pw0
new_image = np.pad(
new_image,
((ph0, ph1), (pw0, pw1), (0, 0)),
mode="constant",
constant_values=((0, 0), (0, 0), (0, 0)),
)
new_image = PIL.Image.fromarray(new_image)
return new_image
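# --- Hedged example (not part of the original file) ---
# Recenters and pads a synthetic RGBA image so the foreground occupies roughly
# `ratio` of the output; the square size and ratio are illustrative.
def _demo_resize_foreground():
    rgba = np.zeros((64, 64, 4), dtype=np.uint8)
    rgba[16:48, 16:48] = 255                               # a 32x32 opaque white square
    image = Image.fromarray(rgba, mode='RGBA')
    out = resize_foreground(image, ratio=0.85)
    return out.size                                        # (36, 36)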
def images_to_video(
images: torch.Tensor,
output_path: str,
fps: int = 30,
) -> None:
# images: (N, C, H, W)
video_dir = os.path.dirname(output_path)
video_name = os.path.basename(output_path)
os.makedirs(video_dir, exist_ok=True)
frames = []
for i in range(len(images)):
frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
f"Frame shape mismatch: {frame.shape} vs {images.shape}"
assert frame.min() >= 0 and frame.max() <= 255, \
f"Frame value out of range: {frame.min()} ~ {frame.max()}"
frames.append(frame)
imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10)
def save_video(
frames: torch.Tensor,
output_path: str,
fps: int = 30,
) -> None:
# frames: (N, C, H, W)
frames = [(frame.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) for frame in frames]
writer = imageio.get_writer(output_path, fps=fps)
for frame in frames:
writer.append_data(frame)
writer.close()
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import torch
import xatlas
import trimesh
import cv2
import numpy as np
import nvdiffrast.torch as dr
from PIL import Image
def save_obj(pointnp_px3, facenp_fx3, colornp_px3, fpath):
pointnp_px3 = pointnp_px3 @ np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]])
facenp_fx3 = facenp_fx3[:, [2, 1, 0]]
mesh = trimesh.Trimesh(
vertices=pointnp_px3,
faces=facenp_fx3,
vertex_colors=colornp_px3,
)
mesh.export(fpath, 'obj')
def save_glb(pointnp_px3, facenp_fx3, colornp_px3, fpath):
pointnp_px3 = pointnp_px3 @ np.array([[-1, 0, 0], [0, 1, 0], [0, 0, -1]])
mesh = trimesh.Trimesh(
vertices=pointnp_px3,
faces=facenp_fx3,
vertex_colors=colornp_px3,
)
mesh.export(fpath, 'glb')
def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname):
import os
fol, na = os.path.split(fname)
na, _ = os.path.splitext(na)
matname = '%s/%s.mtl' % (fol, na)
fid = open(matname, 'w')
fid.write('newmtl material_0\n')
fid.write('Kd 1 1 1\n')
fid.write('Ka 0 0 0\n')
fid.write('Ks 0.4 0.4 0.4\n')
fid.write('Ns 10\n')
fid.write('illum 2\n')
fid.write('map_Kd %s.png\n' % na)
fid.close()
####
fid = open(fname, 'w')
fid.write('mtllib %s.mtl\n' % na)
for pidx, p in enumerate(pointnp_px3):
pp = p
fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2]))
for pidx, p in enumerate(tcoords_px2):
pp = p
fid.write('vt %f %f\n' % (pp[0], pp[1]))
fid.write('usemtl material_0\n')
for i, f in enumerate(facenp_fx3):
f1 = f + 1
f2 = facetex_fx3[i] + 1
fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
fid.close()
# save texture map
lo, hi = 0, 1
img = np.asarray(texmap_hxwx3, dtype=np.float32)
img = (img - lo) * (255 / (hi - lo))
img = img.clip(0, 255)
mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True)
mask = (mask <= 3.0).astype(np.float32)
kernel = np.ones((3, 3), 'uint8')
dilate_img = cv2.dilate(img, kernel, iterations=1)
img = img * (1 - mask) + dilate_img * mask
img = img.clip(0, 255).astype(np.uint8)
Image.fromarray(np.ascontiguousarray(img[::-1, :, :]), 'RGB').save(f'{fol}/{na}.png')
def loadobj(meshfile):
v = []
f = []
meshfp = open(meshfile, 'r')
for line in meshfp.readlines():
data = line.strip().split(' ')
data = [da for da in data if len(da) > 0]
if len(data) != 4:
continue
if data[0] == 'v':
v.append([float(d) for d in data[1:]])
if data[0] == 'f':
data = [da.split('/')[0] for da in data]
f.append([int(d) for d in data[1:]])
meshfp.close()
# torch needs int64
facenp_fx3 = np.array(f, dtype=np.int64) - 1
pointnp_px3 = np.array(v, dtype=np.float32)
return pointnp_px3, facenp_fx3
def loadobjtex(meshfile):
v = []
vt = []
f = []
ft = []
meshfp = open(meshfile, 'r')
for line in meshfp.readlines():
data = line.strip().split(' ')
data = [da for da in data if len(da) > 0]
if not ((len(data) == 3) or (len(data) == 4) or (len(data) == 5)):
continue
if data[0] == 'v':
assert len(data) == 4
v.append([float(d) for d in data[1:]])
if data[0] == 'vt':
if len(data) == 3 or len(data) == 4:
vt.append([float(d) for d in data[1:3]])
if data[0] == 'f':
data = [da.split('/') for da in data]
if len(data) == 4:
f.append([int(d[0]) for d in data[1:]])
ft.append([int(d[1]) for d in data[1:]])
elif len(data) == 5:
idx1 = [1, 2, 3]
data1 = [data[i] for i in idx1]
f.append([int(d[0]) for d in data1])
ft.append([int(d[1]) for d in data1])
idx2 = [1, 3, 4]
data2 = [data[i] for i in idx2]
f.append([int(d[0]) for d in data2])
ft.append([int(d[1]) for d in data2])
meshfp.close()
# torch needs int64
facenp_fx3 = np.array(f, dtype=np.int64) - 1
ftnp_fx3 = np.array(ft, dtype=np.int64) - 1
pointnp_px3 = np.array(v, dtype=np.float32)
uvs = np.array(vt, dtype=np.float32)
return pointnp_px3, facenp_fx3, uvs, ftnp_fx3
# ==============================================================================================
def interpolate(attr, rast, attr_idx, rast_db=None):
return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all')
def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution):
vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy())
# Convert to tensors
indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device)
mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device)
# scale UVs to clip space
uv_clip = uvs[None, ...] * 2.0 - 1.0
# pad to four component coordinate
uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1)
# rasterize
rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution))
# Interpolate world space position
gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int())
mask = rast[..., 3:4] > 0
return uvs, mesh_tex_idx, gb_pos, mask
import importlib
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
return total_params
def instantiate_from_config(config):
if "target" not in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
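# --- Hedged usage sketch (not part of the original file) ---
# Instantiates an object from a {target, params} config dict, the convention used by
# the training configs; torch.nn.Linear is just a stand-in target.
def _demo_instantiate_from_config():
    config = {
        'target': 'torch.nn.Linear',
        'params': {'in_features': 8, 'out_features': 4},
    }
    layer = instantiate_from_config(config)
    return layer                                           # torch.nn.Linear(8, 4)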
import os, sys
import argparse
import shutil
import subprocess
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
from src.utils.train_util import instantiate_from_config
@rank_zero_only
def rank_zero_print(*args):
print(*args)
def get_parser(**parser_kwargs):
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
parser = argparse.ArgumentParser(**parser_kwargs)
parser.add_argument(
"-r",
"--resume",
type=str,
default=None,
help="resume from checkpoint",
)
parser.add_argument(
"--resume_weights_only",
action="store_true",
help="only resume model weights",
)
parser.add_argument(
"-b",
"--base",
type=str,
default="base_config.yaml",
help="path to base configs",
)
parser.add_argument(
"-n",
"--name",
type=str,
default="",
help="experiment name",
)
parser.add_argument(
"--num_nodes",
type=int,
default=1,
help="number of nodes to use",
)
parser.add_argument(
"--gpus",
type=str,
default="0,",
help="gpu ids to use",
)
parser.add_argument(
"-s",
"--seed",
type=int,
default=42,
help="seed for seed_everything",
)
parser.add_argument(
"-l",
"--logdir",
type=str,
default="logs",
help="directory for logging data",
)
return parser
class SetupCallback(Callback):
def __init__(self, resume, logdir, ckptdir, cfgdir, config):
super().__init__()
self.resume = resume
self.logdir = logdir
self.ckptdir = ckptdir
self.cfgdir = cfgdir
self.config = config
def on_fit_start(self, trainer, pl_module):
if trainer.global_rank == 0:
# Create logdirs and save configs
os.makedirs(self.logdir, exist_ok=True)
os.makedirs(self.ckptdir, exist_ok=True)
os.makedirs(self.cfgdir, exist_ok=True)
rank_zero_print("Project config")
rank_zero_print(OmegaConf.to_yaml(self.config))
OmegaConf.save(self.config,
os.path.join(self.cfgdir, "project.yaml"))
class CodeSnapshot(Callback):
"""
Modified from https://github.com/threestudio-project/threestudio/blob/main/threestudio/utils/callbacks.py#L60
"""
def __init__(self, savedir):
self.savedir = savedir
def get_file_list(self):
return [
b.decode()
for b in set(
subprocess.check_output(
'git ls-files -- ":!:configs/*"', shell=True
).splitlines()
)
| set( # hard code, TODO: use config to exclude folders or files
subprocess.check_output(
"git ls-files --others --exclude-standard", shell=True
).splitlines()
)
]
@rank_zero_only
def save_code_snapshot(self):
os.makedirs(self.savedir, exist_ok=True)
for f in self.get_file_list():
if not os.path.exists(f) or os.path.isdir(f):
continue
os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True)
shutil.copyfile(f, os.path.join(self.savedir, f))
def on_fit_start(self, trainer, pl_module):
try:
self.save_code_snapshot()
except Exception:
rank_zero_warn(
"Code snapshot is not saved. Please make sure you have git installed and are in a git repository."
)
if __name__ == "__main__":
# add cwd for convenience and to make classes in this file available when
# running as `python train.py`
sys.path.append(os.getcwd())
parser = get_parser()
opt, unknown = parser.parse_known_args()
cfg_fname = os.path.split(opt.base)[-1]
cfg_name = os.path.splitext(cfg_fname)[0]
exp_name = "-" + opt.name if opt.name != "" else ""
logdir = os.path.join(opt.logdir, cfg_name+exp_name)
ckptdir = os.path.join(logdir, "checkpoints")
cfgdir = os.path.join(logdir, "configs")
codedir = os.path.join(logdir, "code")
seed_everything(opt.seed)
# init configs
config = OmegaConf.load(opt.base)
lightning_config = config.lightning
trainer_config = lightning_config.trainer
trainer_config["accelerator"] = "gpu"
rank_zero_print(f"Running on GPUs {opt.gpus}")
ngpu = len(opt.gpus.strip(",").split(','))
trainer_config['devices'] = ngpu
trainer_opt = argparse.Namespace(**trainer_config)
lightning_config.trainer = trainer_config
# model
model = instantiate_from_config(config.model)
if opt.resume and opt.resume_weights_only:
model = model.__class__.load_from_checkpoint(opt.resume, **config.model.params)
model.logdir = logdir
# trainer and callbacks
trainer_kwargs = dict()
# logger
default_logger_cfg = {
"target": "pytorch_lightning.loggers.TensorBoardLogger",
"params": {
"name": "tensorboard",
"save_dir": logdir,
"version": "0",
}
}
logger_cfg = OmegaConf.merge(default_logger_cfg)
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
# model checkpoint
default_modelckpt_cfg = {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": ckptdir,
"filename": "{step:08}",
"verbose": True,
"save_last": True,
"every_n_train_steps": 5000,
"save_top_k": -1, # save all checkpoints
}
}
if "modelcheckpoint" in lightning_config:
modelckpt_cfg = lightning_config.modelcheckpoint
else:
modelckpt_cfg = OmegaConf.create()
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
# callbacks
default_callbacks_cfg = {
"setup_callback": {
"target": "train.SetupCallback",
"params": {
"resume": opt.resume,
"logdir": logdir,
"ckptdir": ckptdir,
"cfgdir": cfgdir,
"config": config,
}
},
"learning_rate_logger": {
"target": "pytorch_lightning.callbacks.LearningRateMonitor",
"params": {
"logging_interval": "step",
}
},
"code_snapshot": {
"target": "train.CodeSnapshot",
"params": {
"savedir": codedir,
}
},
}
default_callbacks_cfg["checkpoint_callback"] = modelckpt_cfg
if "callbacks" in lightning_config:
callbacks_cfg = lightning_config.callbacks
else:
callbacks_cfg = OmegaConf.create()
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
trainer_kwargs["callbacks"] = [
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
trainer_kwargs['precision'] = '32-true'
trainer_kwargs["strategy"] = DDPStrategy(find_unused_parameters=True)
# trainer
trainer = Trainer(**trainer_config, **trainer_kwargs, num_nodes=opt.num_nodes)
trainer.logdir = logdir
# data
data = instantiate_from_config(config.data)
data.prepare_data()
data.setup("fit")
# configure learning rate
base_lr = config.model.base_learning_rate
if 'accumulate_grad_batches' in lightning_config.trainer:
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
else:
accumulate_grad_batches = 1
rank_zero_print(f"accumulate_grad_batches = {accumulate_grad_batches}")
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
model.learning_rate = base_lr
rank_zero_print("++++ NOT USING LR SCALING ++++")
rank_zero_print(f"Setting learning rate to {model.learning_rate:.2e}")
# run training loop
if opt.resume and not opt.resume_weights_only:
trainer.fit(model, data, ckpt_path=opt.resume)
else:
trainer.fit(model, data)
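# --- Hedged usage note (not part of the original file) ---
# With the argument parser above, a typical launch looks like
#   python train.py --base configs/<your_config>.yaml --gpus 0,1,2,3 --name my_run
# where the script name, config path and run name are placeholders.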
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from tqdm import tqdm
from torchvision.transforms import v2
from torchvision.utils import make_grid, save_image
from einops import rearrange
from src.utils.train_util import instantiate_from_config
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, DDPMScheduler, UNet2DConditionModel
from .pipeline import RefOnlyNoisedUNet
def scale_latents(latents):
latents = (latents - 0.22) * 0.75
return latents
def unscale_latents(latents):
latents = latents / 0.75 + 0.22
return latents
def scale_image(image):
image = image * 0.5 / 0.8
return image
def unscale_image(image):
image = image / 0.5 * 0.8
return image
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
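# --- Hedged example (not part of the original file) ---
# extract_into_tensor gathers one scalar per batch element from a per-timestep
# schedule and reshapes it to broadcast against an image-shaped tensor.
def _demo_extract_into_tensor():
    schedule = torch.linspace(0., 1., 1000)                # e.g. sqrt_alphas_cumprod
    t = torch.tensor([10, 500])                            # one timestep per batch element
    x = torch.randn(2, 4, 32, 32)
    coeff = extract_into_tensor(schedule, t, x.shape)
    return coeff.shape                                     # (2, 1, 1, 1)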
class MVDiffusion(pl.LightningModule):
def __init__(
self,
stable_diffusion_config,
drop_cond_prob=0.1,
):
super(MVDiffusion, self).__init__()
self.drop_cond_prob = drop_cond_prob
self.register_schedule()
# init modules
pipeline = DiffusionPipeline.from_pretrained(**stable_diffusion_config)
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipeline.scheduler.config, timestep_spacing='trailing'
)
self.pipeline = pipeline
train_sched = DDPMScheduler.from_config(self.pipeline.scheduler.config)
if isinstance(self.pipeline.unet, UNet2DConditionModel):
self.pipeline.unet = RefOnlyNoisedUNet(self.pipeline.unet, train_sched, self.pipeline.scheduler)
self.train_scheduler = train_sched # use ddpm scheduler during training
self.unet = pipeline.unet
# validation output buffer
self.validation_step_outputs = []
def register_schedule(self):
self.num_timesteps = 1000
# replace the scaled_linear schedule with a linear schedule, as in Zero123++
beta_start = 0.00085
beta_end = 0.0120
betas = torch.linspace(beta_start, beta_end, 1000, dtype=torch.float32)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]], 0)
self.register_buffer('betas', betas.float())
self.register_buffer('alphas_cumprod', alphas_cumprod.float())
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float())
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod).float())
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod).float())
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod).float())
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1).float())
def on_fit_start(self):
device = torch.device(f'cuda:{self.global_rank}')
self.pipeline.to(device)
if self.global_rank == 0:
os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
def prepare_batch_data(self, batch):
# prepare stable diffusion input
cond_imgs = batch['cond_imgs'] # (B, C, H, W)
cond_imgs = cond_imgs.to(self.device)
# randomly resize the condition image
cond_size = np.random.randint(128, 513)
cond_imgs = v2.functional.resize(cond_imgs, cond_size, interpolation=3, antialias=True).clamp(0, 1)
target_imgs = batch['target_imgs'] # (B, 6, C, H, W)
target_imgs = v2.functional.resize(target_imgs, 320, interpolation=3, antialias=True).clamp(0, 1)
target_imgs = rearrange(target_imgs, 'b (x y) c h w -> b c (x h) (y w)', x=3, y=2) # (B, C, 3H, 2W)
target_imgs = target_imgs.to(self.device)
return cond_imgs, target_imgs
@torch.no_grad()
def forward_vision_encoder(self, images):
dtype = next(self.pipeline.vision_encoder.parameters()).dtype
image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])]
image_pt = self.pipeline.feature_extractor_clip(images=image_pil, return_tensors="pt").pixel_values
image_pt = image_pt.to(device=self.device, dtype=dtype)
global_embeds = self.pipeline.vision_encoder(image_pt, output_hidden_states=False).image_embeds
global_embeds = global_embeds.unsqueeze(-2)
encoder_hidden_states = self.pipeline._encode_prompt("", self.device, 1, False)[0]
ramp = global_embeds.new_tensor(self.pipeline.config.ramping_coefficients).unsqueeze(-1)
encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
return encoder_hidden_states
@torch.no_grad()
def encode_condition_image(self, images):
dtype = next(self.pipeline.vae.parameters()).dtype
image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])]
image_pt = self.pipeline.feature_extractor_vae(images=image_pil, return_tensors="pt").pixel_values
image_pt = image_pt.to(device=self.device, dtype=dtype)
latents = self.pipeline.vae.encode(image_pt).latent_dist.sample()
return latents
@torch.no_grad()
def encode_target_images(self, images):
dtype = next(self.pipeline.vae.parameters()).dtype
# equivalent to scaling images to [-1, 1] first and then calling scale_image
images = (images - 0.5) / 0.8 # [-0.625, 0.625]
posterior = self.pipeline.vae.encode(images.to(dtype)).latent_dist
latents = posterior.sample() * self.pipeline.vae.config.scaling_factor
latents = scale_latents(latents)
return latents
def forward_unet(self, latents, t, prompt_embeds, cond_latents):
dtype = next(self.pipeline.unet.parameters()).dtype
latents = latents.to(dtype)
prompt_embeds = prompt_embeds.to(dtype)
cond_latents = cond_latents.to(dtype)
cross_attention_kwargs = dict(cond_lat=cond_latents)
pred_noise = self.pipeline.unet(
latents,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
return pred_noise
def predict_start_from_z_and_v(self, x_t, t, v):
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
)
def get_v(self, x, noise, t):
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
)
def training_step(self, batch, batch_idx):
# get input
cond_imgs, target_imgs = self.prepare_batch_data(batch)
# sample random timestep
B = cond_imgs.shape[0]
t = torch.randint(0, self.num_timesteps, size=(B,)).long().to(self.device)
# classifier-free guidance
if np.random.rand() < self.drop_cond_prob:
prompt_embeds = self.pipeline._encode_prompt([""]*B, self.device, 1, False)
cond_latents = self.encode_condition_image(torch.zeros_like(cond_imgs))
else:
prompt_embeds = self.forward_vision_encoder(cond_imgs)
cond_latents = self.encode_condition_image(cond_imgs)
latents = self.encode_target_images(target_imgs)
noise = torch.randn_like(latents)
latents_noisy = self.train_scheduler.add_noise(latents, noise, t)
v_pred = self.forward_unet(latents_noisy, t, prompt_embeds, cond_latents)
v_target = self.get_v(latents, noise, t)
loss, loss_dict = self.compute_loss(v_pred, v_target)
# logging
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False)
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
if self.global_step % 500 == 0 and self.global_rank == 0:
with torch.no_grad():
latents_pred = self.predict_start_from_z_and_v(latents_noisy, t, v_pred)
latents = unscale_latents(latents_pred)
images = unscale_image(self.pipeline.vae.decode(latents / self.pipeline.vae.config.scaling_factor, return_dict=False)[0]) # [-1, 1]
images = (images * 0.5 + 0.5).clamp(0, 1)
images = torch.cat([target_imgs, images], dim=-2)
grid = make_grid(images, nrow=images.shape[0], normalize=True, value_range=(0, 1))
save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))
return loss
def compute_loss(self, noise_pred, noise_gt):
loss = F.mse_loss(noise_pred, noise_gt)
prefix = 'train'
loss_dict = {}
loss_dict.update({f'{prefix}/loss': loss})
return loss, loss_dict
@torch.no_grad()
def validation_step(self, batch, batch_idx):
# get input
cond_imgs, target_imgs = self.prepare_batch_data(batch)
images_pil = [v2.functional.to_pil_image(cond_imgs[i]) for i in range(cond_imgs.shape[0])]
outputs = []
for cond_img in images_pil:
latent = self.pipeline(cond_img, num_inference_steps=75, output_type='latent').images
image = unscale_image(self.pipeline.vae.decode(latent / self.pipeline.vae.config.scaling_factor, return_dict=False)[0]) # [-1, 1]
image = (image * 0.5 + 0.5).clamp(0, 1)
outputs.append(image)
outputs = torch.cat(outputs, dim=0).to(self.device)
images = torch.cat([target_imgs, outputs], dim=-2)
self.validation_step_outputs.append(images)
@torch.no_grad()
def on_validation_epoch_end(self):
images = torch.cat(self.validation_step_outputs, dim=0)
all_images = self.all_gather(images)
all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
if self.global_rank == 0:
grid = make_grid(all_images, nrow=8, normalize=True, value_range=(0, 1))
save_image(grid, os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png'))
self.validation_step_outputs.clear() # free memory
def configure_optimizers(self):
lr = self.learning_rate
optimizer = torch.optim.AdamW(self.unet.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)
return {'optimizer': optimizer, 'lr_scheduler': scheduler}
from typing import Any, Dict, Optional
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
import numpy
import torch
import torch.nn as nn
import torch.utils.checkpoint
import torch.distributed
import transformers
from collections import OrderedDict
from PIL import Image
from torchvision import transforms
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
import diffusers
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline,
EulerAncestralDiscreteScheduler,
UNet2DConditionModel,
ImagePipelineOutput
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor, AttnProcessor2_0
from diffusers.utils.import_utils import is_xformers_available
def to_rgb_image(maybe_rgba: Image.Image):
if maybe_rgba.mode == 'RGB':
return maybe_rgba
elif maybe_rgba.mode == 'RGBA':
rgba = maybe_rgba
img = numpy.random.randint(255, 256, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8)
img = Image.fromarray(img, 'RGB')
img.paste(rgba, mask=rgba.getchannel('A'))
return img
else:
raise ValueError("Unsupported image type.", maybe_rgba.mode)
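# --- Hedged example (not part of the original file) ---
# to_rgb_image composites an RGBA input onto a white background; the color and
# size here are illustrative.
def _demo_to_rgb_image():
    rgba = Image.new('RGBA', (64, 64), (255, 0, 0, 128))   # half-transparent red
    rgb = to_rgb_image(rgba)
    return rgb.mode, rgb.getpixel((0, 0))                  # 'RGB', roughly (255, 127, 127)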
class ReferenceOnlyAttnProc(torch.nn.Module):
def __init__(
self,
chained_proc,
enabled=False,
name=None
) -> None:
super().__init__()
self.enabled = enabled
self.chained_proc = chained_proc
self.name = name
def __call__(
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
mode="w", ref_dict: dict = None, is_cfg_guidance = False
) -> Any:
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
if self.enabled and is_cfg_guidance:
res0 = self.chained_proc(attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask)
hidden_states = hidden_states[1:]
encoder_hidden_states = encoder_hidden_states[1:]
if self.enabled:
if mode == 'w':
ref_dict[self.name] = encoder_hidden_states
elif mode == 'r':
encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
elif mode == 'm':
encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
else:
assert False, mode
res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
if self.enabled and is_cfg_guidance:
res = torch.cat([res0, res])
return res
class RefOnlyNoisedUNet(torch.nn.Module):
def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None:
super().__init__()
self.unet = unet
self.train_sched = train_sched
self.val_sched = val_sched
unet_lora_attn_procs = dict()
for name, _ in unet.attn_processors.items():
if torch.__version__ >= '2.0':
default_attn_proc = AttnProcessor2_0()
elif is_xformers_available():
default_attn_proc = XFormersAttnProcessor()
else:
default_attn_proc = AttnProcessor()
unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
)
unet.set_attn_processor(unet_lora_attn_procs)
def __getattr__(self, name: str):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.unet, name)
def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):
if is_cfg_guidance:
encoder_hidden_states = encoder_hidden_states[1:]
class_labels = class_labels[1:]
self.unet(
noisy_cond_lat, timestep,
encoder_hidden_states=encoder_hidden_states,
class_labels=class_labels,
cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
**kwargs
)
def forward(
self, sample, timestep, encoder_hidden_states, class_labels=None,
*args, cross_attention_kwargs,
down_block_res_samples=None, mid_block_res_sample=None,
**kwargs
):
cond_lat = cross_attention_kwargs['cond_lat']
is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)
noise = torch.randn_like(cond_lat)
if self.training:
noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
else:
noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
ref_dict = {}
self.forward_cond(
noisy_cond_lat, timestep,
encoder_hidden_states, class_labels,
ref_dict, is_cfg_guidance, **kwargs
)
weight_dtype = self.unet.dtype
return self.unet(
sample, timestep,
encoder_hidden_states, *args,
class_labels=class_labels,
cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
down_block_additional_residuals=[
sample.to(dtype=weight_dtype) for sample in down_block_res_samples
] if down_block_res_samples is not None else None,
mid_block_additional_residual=(
mid_block_res_sample.to(dtype=weight_dtype)
if mid_block_res_sample is not None else None
),
**kwargs
)
def scale_latents(latents):
latents = (latents - 0.22) * 0.75
return latents
def unscale_latents(latents):
latents = latents / 0.75 + 0.22
return latents
def scale_image(image):
image = image * 0.5 / 0.8
return image
def unscale_image(image):
image = image / 0.5 * 0.8
return image
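# --- Hedged check (not part of the original file) ---
# scale_latents/unscale_latents and scale_image/unscale_image are exact inverses,
# used to move between the VAE's range and the range the UNet expects.
def _demo_scaling_round_trip():
    latents = torch.randn(1, 4, 120, 80)
    image = torch.rand(1, 3, 960, 640)
    assert torch.allclose(unscale_latents(scale_latents(latents)), latents, atol=1e-6)
    assert torch.allclose(unscale_image(scale_image(image)), image, atol=1e-6)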
class DepthControlUNet(torch.nn.Module):
def __init__(self, unet: RefOnlyNoisedUNet, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0) -> None:
super().__init__()
self.unet = unet
if controlnet is None:
self.controlnet = diffusers.ControlNetModel.from_unet(unet.unet)
else:
self.controlnet = controlnet
DefaultAttnProc = AttnProcessor2_0
if is_xformers_available():
DefaultAttnProc = XFormersAttnProcessor
self.controlnet.set_attn_processor(DefaultAttnProc())
self.conditioning_scale = conditioning_scale
def __getattr__(self, name: str):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.unet, name)
def forward(self, sample, timestep, encoder_hidden_states, class_labels=None, *args, cross_attention_kwargs: dict, **kwargs):
cross_attention_kwargs = dict(cross_attention_kwargs)
control_depth = cross_attention_kwargs.pop('control_depth')
down_block_res_samples, mid_block_res_sample = self.controlnet(
sample,
timestep,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=control_depth,
conditioning_scale=self.conditioning_scale,
return_dict=False,
)
return self.unet(
sample,
timestep,
encoder_hidden_states=encoder_hidden_states,
down_block_res_samples=down_block_res_samples,
mid_block_res_sample=mid_block_res_sample,
cross_attention_kwargs=cross_attention_kwargs
)
class ModuleListDict(torch.nn.Module):
def __init__(self, procs: dict) -> None:
super().__init__()
self.keys = sorted(procs.keys())
self.values = torch.nn.ModuleList(procs[k] for k in self.keys)
def __getitem__(self, key):
return self.values[self.keys.index(key)]
class SuperNet(torch.nn.Module):
def __init__(self, state_dict: Dict[str, torch.Tensor]):
super().__init__()
state_dict = OrderedDict((k, state_dict[k]) for k in sorted(state_dict.keys()))
self.layers = torch.nn.ModuleList(state_dict.values())
self.mapping = dict(enumerate(state_dict.keys()))
self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
# .processor for unet, .self_attn for text encoder
self.split_keys = [".processor", ".self_attn"]
# we add a hook to state_dict() and load_state_dict() so that the
# naming fits with `unet.attn_processors`
def map_to(module, state_dict, *args, **kwargs):
new_state_dict = {}
for key, value in state_dict.items():
num = int(key.split(".")[1]) # 0 is always "layers"
new_key = key.replace(f"layers.{num}", module.mapping[num])
new_state_dict[new_key] = value
return new_state_dict
def remap_key(key, state_dict):
for k in self.split_keys:
if k in key:
return key.split(k)[0] + k
return key.split('.')[0]
def map_from(module, state_dict, *args, **kwargs):
all_keys = list(state_dict.keys())
for key in all_keys:
replace_key = remap_key(key, state_dict)
new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
state_dict[new_key] = state_dict[key]
del state_dict[key]
self._register_state_dict_hook(map_to)
self._register_load_state_dict_pre_hook(map_from, with_module=True)
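# Zero123Plus pipeline: conditions Stable Diffusion on a single input image via
# reference attention (cond_lat) plus a ramped global CLIP image embedding, and
# generates a multi-view image grid (640x960 by default) in one sampling run.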
class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
tokenizer: transformers.CLIPTokenizer
text_encoder: transformers.CLIPTextModel
vision_encoder: transformers.CLIPVisionModelWithProjection
feature_extractor_clip: transformers.CLIPImageProcessor
unet: UNet2DConditionModel
scheduler: diffusers.schedulers.KarrasDiffusionSchedulers
vae: AutoencoderKL
ramping: nn.Linear
feature_extractor_vae: transformers.CLIPImageProcessor
depth_transforms_multi = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
vision_encoder: transformers.CLIPVisionModelWithProjection,
feature_extractor_clip: CLIPImageProcessor,
feature_extractor_vae: CLIPImageProcessor,
ramping_coefficients: Optional[list] = None,
safety_checker=None,
):
DiffusionPipeline.__init__(self)
self.register_modules(
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
unet=unet, scheduler=scheduler, safety_checker=None,
vision_encoder=vision_encoder,
feature_extractor_clip=feature_extractor_clip,
feature_extractor_vae=feature_extractor_vae
)
self.register_to_config(ramping_coefficients=ramping_coefficients)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
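# Wraps the raw UNet in RefOnlyNoisedUNet on first use; called from __call__
# and add_controlnet, and a no-op once the UNet is already wrapped.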
def prepare(self):
train_sched = DDPMScheduler.from_config(self.scheduler.config)
if isinstance(self.unet, UNet2DConditionModel):
self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval()
def add_controlnet(self, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0):
self.prepare()
self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale)
return SuperNet(OrderedDict([('controlnet', self.unet.controlnet)]))
def encode_condition_image(self, image: torch.Tensor):
image = self.vae.encode(image).latent_dist.sample()
return image
@torch.no_grad()
def __call__(
self,
image: Image.Image = None,
prompt = "",
*args,
num_images_per_prompt: Optional[int] = 1,
guidance_scale=4.0,
depth_image: Image.Image = None,
output_type: Optional[str] = "pil",
width=640,
height=960,
num_inference_steps=28,
return_dict=True,
**kwargs
):
self.prepare()
if image is None:
raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.")
assert not isinstance(image, torch.Tensor)
image = to_rgb_image(image)
image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values
if depth_image is not None and hasattr(self.unet, "controlnet"):
depth_image = to_rgb_image(depth_image)
depth_image = self.depth_transforms_multi(depth_image).to(
device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype
)
image = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)
cond_lat = self.encode_condition_image(image)
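# For classifier-free guidance, the unconditional branch uses the latent of an
# all-zero (black) image, concatenated in front of the conditional latent.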
if guidance_scale > 1:
negative_lat = self.encode_condition_image(torch.zeros_like(image))
cond_lat = torch.cat([negative_lat, cond_lat])
encoded = self.vision_encoder(image_2, output_hidden_states=False)
global_embeds = encoded.image_embeds
global_embeds = global_embeds.unsqueeze(-2)
if hasattr(self, "encode_prompt"):
encoder_hidden_states = self.encode_prompt(
prompt,
self.device,
num_images_per_prompt,
False
)[0]
else:
encoder_hidden_states = self._encode_prompt(
prompt,
self.device,
num_images_per_prompt,
False
)
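# Blend the global CLIP image embedding into the prompt embeddings, weighted
# per token position by the learned ramping coefficients.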
ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
cak = dict(cond_lat=cond_lat)
if hasattr(self.unet, "controlnet"):
cak['control_depth'] = depth_image
latents: torch.Tensor = super().__call__(
None,
*args,
cross_attention_kwargs=cak,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=encoder_hidden_states,
num_inference_steps=num_inference_steps,
output_type='latent',
width=width,
height=height,
**kwargs
).images
latents = unscale_latents(latents)
if not output_type == "latent":
image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0])
else:
image = latents
image = self.image_processor.postprocess(image, output_type=output_type)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
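# Example usage (a minimal sketch, not part of the original pipeline; it assumes
# the "sudo-ai/zero123plus-v1.2" checkpoint downloaded below and a local
# `input.png`):
#
#   import torch
#   from PIL import Image
#   pipeline = Zero123PlusPipeline.from_pretrained(
#       "sudo-ai/zero123plus-v1.2", torch_dtype=torch.float16
#   ).to("cuda")
#   result = pipeline(Image.open("input.png"), num_inference_steps=28)
#   result.images[0].save("output.png")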
# pip install -U "huggingface_hub[cli]"  # provides the huggingface-cli command
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
model_list = [
"sudo-ai/zero123plus-v1.2",
"TencentARC/InstantMesh"
]
for model_path in model_list:
os.system(
f"huggingface-cli download --resume-download {model_path} --local-dir ./{model_path} --local-dir-use-symlinks False")
#!/bin/bash
cd /root/InstantMesh
python app.py
{
"cells": [
{
"cell_type": "markdown",
"id": "e5c5a211-2ccd-4341-af10-ac546484b91f",
"metadata": {
"tags": []
},
"source": [
"## 说明\n",
"- 启动需要加载模型,需要2分钟左右的时间\n",
"- 启动和重启 Notebook 点上方工具栏中的「重启并运行所有单元格」。出现如下内容就算成功了:\n",
" - `Running on local URL: http://0.0.0.0:7860`\n",
" - `Running on public URL: https://xxxxxxxxxxxxxxx.gradio.live`\n",
"- 通过以下方式开启页面:\n",
" - 控制台打开「自定义服务」了,访问自定义服务端口号设置为7860\n",
" - 直接打开显示的公开链接`public URL`\n",
"\n",
"## 功能介绍\n",
"- 原项目地址:https://github.com/TencentARC/InstantMesh\n",
"- InstantMesh:2D图片到3D模型转化工具,单张图片仅需10秒即可生成高质量3D模型"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53a96614-e2d2-4710-a82b-0d5ca9cb9872",
"metadata": {
"tags": [],
"is_executing": true
},
"outputs": [],
"source": [
"# 启动\n",
"!sh start.sh"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e81ae9d-3a34-43a0-943a-ff5e9d6ce961",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}