Commit 2e7ad6e0 authored by Ruilong Li

proposal seems working

parent b4286720
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import argparse
import math
import os
import random
import time
from typing import Optional
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from datasets.utils import Rays, namedtuple_map
from radiance_fields.ngp import NGPradianceField
from utils import set_random_seed
from nerfacc import ContractionType, ray_marching, rendering
from nerfacc.cuda import ray_pdf_query
def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def render_image(
    # scene
    radiance_field: torch.nn.Module,
    proposal_nets: torch.nn.Module,
    rays: Rays,
    scene_aabb: torch.Tensor,
    # rendering options
    near_plane: Optional[float] = None,
    far_plane: Optional[float] = None,
    render_step_size: float = 1e-3,
    render_bkgd: Optional[torch.Tensor] = None,
    cone_angle: float = 0.0,
    alpha_thre: float = 0.0,
    # test options
    test_chunk_size: int = 8192,
):
    """Render the pixels of an image."""
    rays_shape = rays.origins.shape
    if len(rays_shape) == 3:
        height, width, _ = rays_shape
        num_rays = height * width
        rays = namedtuple_map(
            lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays
        )
    else:
        num_rays, _ = rays_shape

    def sigma_fn(t_starts, t_ends, ray_indices, net=None):
        ray_indices = ray_indices.long()
        t_origins = chunk_rays.origins[ray_indices]
        t_dirs = chunk_rays.viewdirs[ray_indices]
        positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
        if net is not None:
            return net.query_density(positions)
        else:
            return radiance_field.query_density(positions)

    def rgb_sigma_fn(t_starts, t_ends, ray_indices):
        ray_indices = ray_indices.long()
        t_origins = chunk_rays.origins[ray_indices]
        t_dirs = chunk_rays.viewdirs[ray_indices]
        positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
        return radiance_field(positions, t_dirs)

    results = []
    chunk = (
        torch.iinfo(torch.int32).max
        if radiance_field.training
        else test_chunk_size
    )
    for i in range(0, num_rays, chunk):
        chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
        packed_info, t_starts, t_ends, proposal_sample_list = ray_marching(
            chunk_rays.origins,
            chunk_rays.viewdirs,
            scene_aabb=scene_aabb,
            grid=None,
            proposal_nets=proposal_nets,
            sigma_fn=sigma_fn,
            near_plane=near_plane,
            far_plane=far_plane,
            render_step_size=render_step_size,
            stratified=radiance_field.training,
            cone_angle=cone_angle,
            alpha_thre=alpha_thre,
        )
        rgb, opacity, depth, weights = rendering(
            rgb_sigma_fn,
            packed_info,
            t_starts,
            t_ends,
            render_bkgd=render_bkgd,
        )
        if radiance_field.training:
            proposal_sample_list.append(
                (packed_info, t_starts, t_ends, weights)
            )
        chunk_results = [rgb, opacity, depth, len(t_starts)]
        results.append(chunk_results)

    colors, opacities, depths, n_rendering_samples = [
        torch.cat(r, dim=0) if isinstance(r[0], torch.Tensor) else r
        for r in zip(*results)
    ]
    return (
        colors.view((*rays_shape[:-1], -1)),
        opacities.view((*rays_shape[:-1], -1)),
        depths.view((*rays_shape[:-1], -1)),
        sum(n_rendering_samples),
        proposal_sample_list if radiance_field.training else None,
    )

if __name__ == "__main__":

    device = "cuda:0"
    set_random_seed(42)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train_split",
        type=str,
        default="trainval",
        choices=["train", "trainval"],
        help="which train split to use",
    )
    parser.add_argument(
        "--scene",
        type=str,
        default="lego",
        choices=[
            # nerf synthetic
            "chair",
            "drums",
            "ficus",
            "hotdog",
            "lego",
            "materials",
            "mic",
            "ship",
            # mipnerf360 unbounded
            "garden",
            "bicycle",
            "bonsai",
            "counter",
            "kitchen",
            "room",
            "stump",
        ],
        help="which scene to use",
    )
    parser.add_argument(
        "--aabb",
        type=lambda s: [float(item) for item in s.split(",")],
        default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
        help="delimited list input",
    )
    parser.add_argument(
        "--test_chunk_size",
        type=int,
        default=8192,
    )
    parser.add_argument(
        "--unbounded",
        action="store_true",
        help="whether to use unbounded rendering",
    )
    parser.add_argument(
        "--auto_aabb",
        action="store_true",
        help="whether to automatically compute the aabb",
    )
    parser.add_argument("--cone_angle", type=float, default=0.0)
    args = parser.parse_args()

    render_n_samples = 256

    # setup the dataset
    train_dataset_kwargs = {}
    test_dataset_kwargs = {}
    if args.unbounded:
        from datasets.nerf_360_v2 import SubjectLoader

        data_root_fp = "/home/ruilongli/data/360_v2/"
        target_sample_batch_size = 1 << 20
        train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
        test_dataset_kwargs = {"factor": 4}
    else:
        from datasets.nerf_synthetic import SubjectLoader

        data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
        target_sample_batch_size = 1 << 20

    train_dataset = SubjectLoader(
        subject_id=args.scene,
        root_fp=data_root_fp,
        split=args.train_split,
        num_rays=target_sample_batch_size // render_n_samples,
        **train_dataset_kwargs,
    )
    train_dataset.images = train_dataset.images.to(device)
    train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
    train_dataset.K = train_dataset.K.to(device)

    test_dataset = SubjectLoader(
        subject_id=args.scene,
        root_fp=data_root_fp,
        split="test",
        num_rays=None,
        **test_dataset_kwargs,
    )
    test_dataset.images = test_dataset.images.to(device)
    test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
    test_dataset.K = test_dataset.K.to(device)

    if args.auto_aabb:
        camera_locs = torch.cat(
            [train_dataset.camtoworlds, test_dataset.camtoworlds]
        )[:, :3, -1]
        args.aabb = torch.cat(
            [camera_locs.min(dim=0).values, camera_locs.max(dim=0).values]
        ).tolist()
        print("Using auto aabb", args.aabb)

    # setup the scene bounding box.
    if args.unbounded:
        print("Using unbounded rendering")
        contraction_type = ContractionType.UN_BOUNDED_SPHERE
        # contraction_type = ContractionType.UN_BOUNDED_TANH
        scene_aabb = None
        near_plane = 0.2
        far_plane = 1e4
        render_step_size = 1e-2
        alpha_thre = 1e-2
    else:
        contraction_type = ContractionType.AABB
        scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
        near_plane = None
        far_plane = None
        render_step_size = (
            (scene_aabb[3:] - scene_aabb[:3]).max()
            * math.sqrt(3)
            / render_n_samples
        ).item()
        alpha_thre = 0.0
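
    # Proposal network: a small, density-only NGP field (no view directions,
    # coarse hash grid with max_res=64) used to place samples before the full
    # radiance field is queried, in the spirit of mip-NeRF 360's proposal MLP.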
    proposal_nets = torch.nn.ModuleList(
        [
            NGPradianceField(
                aabb=args.aabb,
                use_viewdirs=False,
                hidden_dim=16,
                max_res=64,
                geo_feat_dim=0,
                n_levels=5,
                log2_hashmap_size=17,
            ),
            # NGPradianceField(
            #     aabb=args.aabb,
            #     use_viewdirs=False,
            #     hidden_dim=16,
            #     max_res=256,
            #     geo_feat_dim=0,
            #     n_levels=5,
            #     log2_hashmap_size=17,
            # ),
        ]
    ).to(device)

    # setup the radiance field we want to train.
    max_steps = 20000
    grad_scaler = torch.cuda.amp.GradScaler(2**10)
    radiance_field = NGPradianceField(
        aabb=args.aabb,
        unbounded=args.unbounded,
    ).to(device)
    optimizer = torch.optim.Adam(
        list(radiance_field.parameters()) + list(proposal_nets.parameters()),
        lr=1e-2,
        eps=1e-15,
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[max_steps // 2, max_steps * 3 // 4, max_steps * 9 // 10],
        gamma=0.33,
    )

    # training
    step = 0
    tic = time.time()
    for epoch in range(10000000):
        for i in range(len(train_dataset)):
            radiance_field.train()
            data = train_dataset[i]

            render_bkgd = data["color_bkgd"]
            rays = data["rays"]
            pixels = data["pixels"]

            # render
            (
                rgb,
                acc,
                depth,
                n_rendering_samples,
                proposal_sample_list,
            ) = render_image(
                radiance_field,
                proposal_nets,
                rays,
                scene_aabb,
                # rendering options
                near_plane=near_plane,
                far_plane=far_plane,
                render_step_size=render_step_size,
                render_bkgd=render_bkgd,
                cone_angle=args.cone_angle,
                alpha_thre=alpha_thre,
            )
            if n_rendering_samples == 0:
                continue

            # dynamic batch size for rays to keep sample batch size constant.
            num_rays = len(pixels)
            num_rays = int(
                num_rays
                * (target_sample_batch_size / float(n_rendering_samples))
            )
            train_dataset.update_num_rays(num_rays)
            alive_ray_mask = acc.squeeze(-1) > 0

            # compute loss
            loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])

            (
                packed_info,
                t_starts,
                t_ends,
                weights,
            ) = proposal_sample_list[-1]
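
            # Supervise each proposal level against the final radiance-field
            # weights: `ray_pdf_query` projects the detached NeRF weights onto
            # each proposal level's intervals, and proposal weights that
            # underestimate that target are penalized with
            # clamp(w_gt - w, min=0) ** 2 / (w + eps)
            # (the proposal/interval loss of mip-NeRF 360).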
            for (
                proposal_packed_info,
                proposal_t_starts,
                proposal_t_ends,
                proposal_weights,
            ) in proposal_sample_list[:-1]:
                proposal_weights_gt = ray_pdf_query(
                    packed_info,
                    t_starts,
                    t_ends,
                    weights.detach(),
                    proposal_packed_info,
                    proposal_t_starts,
                    proposal_t_ends,
                ).detach()
                torch.cuda.synchronize()
                loss_interval = (
                    torch.clamp(proposal_weights_gt - proposal_weights, min=0)
                ) ** 2 / (proposal_weights + torch.finfo(torch.float32).eps)
                loss_interval = loss_interval.mean()
                loss += loss_interval * 1.0

            optimizer.zero_grad()
            # do not unscale it because we are using Adam.
            grad_scaler.scale(loss).backward()
            optimizer.step()
            scheduler.step()

            if step % 100 == 0:
                elapsed_time = time.time() - tic
                loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
                print(
                    f"elapsed_time={elapsed_time:.2f}s | step={step} | "
                    f"loss={loss:.5f} | loss_interval={loss_interval:.5f} | "
                    f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
                    f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
                )

            if step > 0 and step % 1000 == 0:
                # evaluation
                radiance_field.eval()

                psnrs = []
                with torch.no_grad():
                    for i in tqdm.tqdm(range(len(test_dataset))):
                        data = test_dataset[i]
                        render_bkgd = data["color_bkgd"]
                        rays = data["rays"]
                        pixels = data["pixels"]

                        # rendering
                        rgb, acc, depth, _, _ = render_image(
                            radiance_field,
                            proposal_nets,
                            rays,
                            scene_aabb,
                            # rendering options
                            near_plane=near_plane,
                            far_plane=far_plane,
                            render_step_size=render_step_size,
                            render_bkgd=render_bkgd,
                            cone_angle=args.cone_angle,
                            alpha_thre=alpha_thre,
                            # test options
                            test_chunk_size=args.test_chunk_size,
                        )
                        mse = F.mse_loss(rgb, pixels)
                        psnr = -10.0 * torch.log(mse) / np.log(10.0)
                        psnrs.append(psnr.item())
                        imageio.imwrite(
                            "acc_binary_test.png",
                            ((acc > 0).float().cpu().numpy() * 255).astype(
                                np.uint8
                            ),
                        )
                        imageio.imwrite(
                            "rgb_test.png",
                            (rgb.cpu().numpy() * 255).astype(np.uint8),
                        )
                        # only evaluate one test view per round.
                        break
                psnr_avg = sum(psnrs) / len(psnrs)
                print(f"evaluation: psnr_avg={psnr_avg}")
                train_dataset.training = True

            if step == max_steps:
                print("training stops")
                exit()

            step += 1
@@ -37,7 +37,7 @@ def ray_resampling(
         resampled_t_starts,
         resampled_t_ends,
     ) = _C.ray_resampling(
-        packed_info.contiguous(),
+        packed_info.contiguous().int(),
         t_starts.contiguous(),
         t_ends.contiguous(),
         weights.contiguous(),
@@ -94,7 +94,7 @@ __global__ void cdf_resampling_kernel(
     const int *packed_info,  // input ray & point indices.
     const scalar_t *starts,  // input start t
     const scalar_t *ends,    // input end t
-    const scalar_t *weights, // transmittance weights
+    const scalar_t *w,       // transmittance weights
     const int *resample_packed_info,
     scalar_t *resample_starts,
     scalar_t *resample_ends)
@@ -111,25 +111,26 @@ __global__ void cdf_resampling_kernel(
     starts += base;
     ends += base;
-    weights += base;
+    w += base;
     resample_starts += resample_base;
     resample_ends += resample_base;
     // normalize weights **per ray**
-    scalar_t weights_sum = 0.0f;
+    scalar_t w_sum = 0.0f;
     for (int j = 0; j < steps; j++)
-        weights_sum += weights[j];
-    scalar_t padding = fmaxf(1e-5f - weights_sum, 0.0f);
-    scalar_t padding_step = padding / steps;
-    weights_sum += padding;
+        w_sum += w[j];
+    // scalar_t padding = fmaxf(1e-10f - weights_sum, 0.0f);
+    // scalar_t padding_step = padding / steps;
+    // weights_sum += padding;
-    int num_bins = resample_steps + 1;
-    scalar_t cdf_step_size = (1.0f - 1.0 / num_bins) / resample_steps;
+    int num_endpoints = resample_steps + 1;
+    scalar_t cdf_pad = 1.0f / (2 * num_endpoints);
+    scalar_t cdf_step_size = (1.0f - 2 * cdf_pad) / resample_steps;
     int idx = 0, j = 0;
-    scalar_t cdf_prev = 0.0f, cdf_next = (weights[idx] + padding_step) / weights_sum;
-    scalar_t cdf_u = 1.0 / (2 * num_bins);
-    while (j < num_bins)
+    scalar_t cdf_prev = 0.0f, cdf_next = w[idx] / w_sum;
+    scalar_t cdf_u = cdf_pad;
+    while (j < num_endpoints)
     {
         if (cdf_u < cdf_next)
         {
@@ -137,26 +138,32 @@ __global__ void cdf_resampling_kernel(
             // resample in this interval
             scalar_t scaling = (ends[idx] - starts[idx]) / (cdf_next - cdf_prev);
             scalar_t t = (cdf_u - cdf_prev) * scaling + starts[idx];
-            if (j < num_bins - 1)
+            // if (j == 100) {
+            //     printf(
+            //         "cdf_u: %.10f, cdf_next: %.10f, cdf_prev: %.10f, scaling: %.10f, t: %.10f, starts[idx]: %.10f, ends[idx]: %.10f\n",
+            //         cdf_u, cdf_next, cdf_prev, scaling, t, starts[idx], ends[idx]);
+            // }
+            if (j < num_endpoints - 1)
                 resample_starts[j] = t;
             if (j > 0)
                 resample_ends[j - 1] = t;
             // going further to next resample
-            cdf_u += cdf_step_size;
+            // cdf_u += cdf_step_size;
             j += 1;
+            cdf_u = j * cdf_step_size + cdf_pad;
         }
         else
         {
             // going to next interval
             idx += 1;
             cdf_prev = cdf_next;
-            cdf_next += (weights[idx] + padding_step) / weights_sum;
+            cdf_next += w[idx] / w_sum;
         }
     }
-    if (j != num_bins)
-    {
-        printf("Error: %d %d %f\n", j, num_bins, weights_sum);
-    }
+    // if (j != num_endpoints)
+    // {
+    //     printf("Error: %d %d %f\n", j, num_endpoints, weights_sum);
+    // }
     return;
 }
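The kernel above is a streaming inverse-CDF sampler: it normalizes the per-ray
weights into a CDF, walks query points placed at the centers of equally sized
CDF bins, and emits an interval endpoint whenever a query point falls inside
the current CDF segment. A single-ray PyTorch sketch of the same scheme (the
helper name `resample_one_ray` is illustrative, not part of nerfacc):

import torch

def resample_one_ray(t_starts, t_ends, weights, resample_steps):
    # Per-ray CDF over the input intervals.
    w_sum = weights.sum()
    cdf = torch.cat([torch.zeros(1), torch.cumsum(weights / w_sum, dim=0)])
    # Query points at the centers of `resample_steps + 1` equal CDF bins,
    # mirroring cdf_pad / cdf_step_size in the kernel.
    num_endpoints = resample_steps + 1
    pad = 1.0 / (2 * num_endpoints)
    u = torch.linspace(pad, 1.0 - pad, num_endpoints)
    # Locate each query point in the CDF and invert by linear interpolation.
    idx = torch.searchsorted(cdf, u, right=True) - 1
    idx = torch.clamp(idx, 0, len(weights) - 1)
    scaling = (t_ends[idx] - t_starts[idx]) / (cdf[idx + 1] - cdf[idx])
    t = t_starts[idx] + (u - cdf[idx]) * scaling
    # Consecutive endpoints define the resampled intervals.
    return t[:-1], t[1:]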
@@ -4,10 +4,12 @@
 import torch

 import nerfacc.cuda as _C

+from .cdf import ray_resampling
 from .contraction import ContractionType
 from .grid import Grid
 from .intersection import ray_aabb_intersect
-from .vol_rendering import render_visibility
+from .pack import unpack_info
+from .vol_rendering import render_visibility, render_weight_from_density

 @torch.no_grad()
@@ -24,6 +26,7 @@ def ray_marching(
     # sigma/alpha function for skipping invisible space
     sigma_fn: Optional[Callable] = None,
     alpha_fn: Optional[Callable] = None,
+    proposal_nets: Optional[torch.nn.Module] = None,
     early_stop_eps: float = 1e-4,
     alpha_thre: float = 0.0,
     # rendering options
@@ -189,6 +192,23 @@ def ray_marching(
         cone_angle,
     )
+    if proposal_nets is not None:
+        proposal_sample_list = []
+        # resample with proposal nets
+        for net, num_samples in zip(proposal_nets, [48]):
+            ray_indices = unpack_info(packed_info)
+            with torch.enable_grad():
+                sigmas = sigma_fn(t_starts, t_ends, ray_indices.long(), net=net)
+            weights = render_weight_from_density(
+                packed_info, t_starts, t_ends, sigmas, early_stop_eps=0
+            )
+            proposal_sample_list.append(
+                (packed_info, t_starts, t_ends, weights)
+            )
+            packed_info, t_starts, t_ends = ray_resampling(
+                packed_info, t_starts, t_ends, weights, n_samples=num_samples
+            )
+
     # skip invisible space
     if sigma_fn is not None or alpha_fn is not None:
         # Query sigma without gradients
@@ -218,4 +238,7 @@ def ray_marching(
         t_ends[masks],
     )
-    return ray_indices, t_starts, t_ends
+    if proposal_nets is not None:
+        return packed_info, t_starts, t_ends, proposal_sample_list
+    else:
+        return packed_info, t_starts, t_ends
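
The proposal branch added above does three things per level: query the
proposal net's densities on the current intervals, convert them to
volume-rendering weights, and resample new intervals proportional to those
weights via `ray_resampling`. A minimal single-ray sketch of the
density-to-weight step (the standard volume-rendering formula; the helper
name `weights_from_density` is illustrative, not the library's packed CUDA
implementation):

import torch

def weights_from_density(sigmas, t_starts, t_ends):
    # alpha_i = 1 - exp(-sigma_i * dt_i), T_i = prod_{j<i} (1 - alpha_j),
    # w_i = T_i * alpha_i.
    dt = t_ends - t_starts
    alphas = 1.0 - torch.exp(-sigmas * dt)
    trans = torch.cumprod(
        torch.cat([torch.ones_like(alphas[:1]), 1.0 - alphas[:-1]], dim=0),
        dim=0,
    )
    return trans * alphas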
import pytest
import torch
from functorch import vmap

from nerfacc import pack_info, ray_marching, ray_resampling
from nerfacc.cuda import ray_pdf_query

device = "cuda:0"
batch_size = 128
eps = torch.finfo(torch.float32).eps

def _interp(x, xp, fp):
    """One-dimensional linear interpolation for monotonically increasing sample
    points.

    Returns the one-dimensional piecewise linear interpolant to a function with
    given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`.

    Args:
        x: the :math:`x`-coordinates at which to evaluate the interpolated
            values.
        xp: the :math:`x`-coordinates of the data points, must be increasing.
        fp: the :math:`y`-coordinates of the data points, same length as `xp`.

    Returns:
        the interpolated values, same size as `x`.
    """
    xp = xp.contiguous()
    x = x.contiguous()
    m = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1])
    b = fp[:-1] - (m * xp[:-1])
    indices = torch.searchsorted(xp, x, right=True) - 1
    indices = torch.clamp(indices, 0, len(m) - 1)
    return m[indices] * x + b[indices]
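
# Example (linear function y = 2x, an illustrative check):
#   _interp(torch.tensor([0.5]), torch.tensor([0.0, 1.0]), torch.tensor([0.0, 2.0]))
#   returns tensor([1.0]).
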
def _integrate_weights(w):
    """Compute the cumulative sum of w, assuming all weight vectors sum to 1.

    The output's size on the last dimension is one greater than that of the
    input, because we're computing the integral corresponding to the endpoints
    of a step function, not the integral of the interior/bin values.

    Args:
        w: Tensor, which will be integrated along the last axis. This is
            assumed to sum to 1 along the last axis, and this function will
            (silently) break if that is not the case.

    Returns:
        cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1
    """
    cw = torch.clamp(torch.cumsum(w[..., :-1], dim=-1), max=1)
    shape = cw.shape[:-1] + (1,)
    # Ensure that the CDF starts with exactly 0 and ends with exactly 1.
    zeros = torch.zeros(shape, device=w.device)
    ones = torch.ones(shape, device=w.device)
    cw0 = torch.cat([zeros, cw, ones], dim=-1)
    return cw0
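
# Worked example (weights already summing to 1):
#   _integrate_weights(torch.tensor([0.2, 0.3, 0.5]))
#   returns tensor([0.0, 0.2, 0.5, 1.0]).
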
def _invert_cdf(u, t, w_logits):
    """Invert the CDF defined by (t, w) at the points specified by u in [0, 1)."""
    # Compute the PDF and CDF for each weight vector.
    w = torch.softmax(w_logits, dim=-1)
    # w = torch.exp(w_logits)
    # w = w / torch.sum(w, dim=-1, keepdim=True)
    cw = _integrate_weights(w)
    # Interpolate into the inverse CDF.
    t_new = vmap(_interp)(u, cw, t)
    return t_new

def _resampling(t, w_logits, num_samples):
    """Piecewise-constant PDF sampling from a step function.

    Args:
        t: [..., num_bins + 1], bin endpoint coordinates (must be sorted).
        w_logits: [..., num_bins], logits corresponding to bin weights.
        num_samples: int, the number of samples.

    Returns:
        t_samples: [..., num_samples], the sampled t values.
    """
    pad = 1 / (2 * num_samples)
    u = torch.linspace(pad, 1.0 - pad - eps, num_samples, device=device)
    u = torch.broadcast_to(u, t.shape[:-1] + (num_samples,))
    return _invert_cdf(u, t, w_logits)
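
# The query points are deterministic CDF midpoints: for num_samples = 4,
# u = [0.125, 0.375, 0.625, 0.875]. This matches the cdf_pad / cdf_step_size
# placement in the CUDA kernel above, with each u at the center of one of
# num_samples equal CDF bins.
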
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_resampling():
-    rays_o = torch.rand((batch_size, 3), device=device)
-    rays_d = torch.randn((batch_size, 3), device=device)
-    rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
+    batch_size = 1024
+    num_bins = 128
+    num_samples = 128
-    ray_indices, t_starts, t_ends = ray_marching(
-        rays_o,
-        rays_d,
-        near_plane=0.1,
-        far_plane=1.0,
-        render_step_size=1e-3,
-    )
+    t = torch.randn((batch_size, num_bins + 1), device=device)
+    t = torch.sort(t, dim=-1).values
+    w_logits = torch.randn((batch_size, num_bins), device=device) * 0.1
+    w = torch.softmax(w_logits, dim=-1)
+    masks = w_logits > 0
+    w_logits[~masks] = -torch.inf
+    t_samples = _resampling(t, w_logits, num_samples + 1)
+    t_starts = t[:, :-1][masks].unsqueeze(-1)
+    t_ends = t[:, 1:][masks].unsqueeze(-1)
+    w_logits = w_logits[masks]
+    w = w[masks]
+    num_steps = masks.long().sum(dim=-1)
+    cum_steps = torch.cumsum(num_steps, dim=0)
+    packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=-1).int()
+    _, t_starts, t_ends = ray_resampling(
+        packed_info, t_starts, t_ends, w, num_samples
+    )
-    packed_info = pack_info(ray_indices, n_rays=batch_size)
-    weights = torch.rand((t_starts.shape[0],), device=device)
-    packed_info, t_starts, t_ends = ray_resampling(
-        packed_info, t_starts, t_ends, weights, n_samples=32
-    )
+    # print(
+    #     (t_starts.view(batch_size, num_samples) - t_samples[:, :-1])
+    #     .abs()
+    #     .max(),
+    #     (t_ends.view(batch_size, num_samples) - t_samples[:, 1:]).abs().max(),
+    # )
+    assert torch.allclose(
+        t_starts.view(batch_size, num_samples), t_samples[:, :-1], atol=1e-3
+    )
+    assert torch.allclose(
+        t_ends.view(batch_size, num_samples), t_samples[:, 1:], atol=1e-3
+    )
-    assert t_starts.shape == t_ends.shape == (batch_size * 32, 1)

def test_pdf_query():
......