Unverified Commit 34855fdc authored by Youtian Lin, committed by GitHub

Include a test mode rendering function for Instant NGP in the examples (#217)

* add test mode render_image_with_occgrid_test

* fix format

* misc fix
parent d3f1e37f
@@ -19,6 +19,7 @@ from examples.utils import (
     MIPNERF360_UNBOUNDED_SCENES,
     NERF_SYNTHETIC_SCENES,
     render_image_with_occgrid,
+    render_image_with_occgrid_test,
     set_random_seed,
 )
 from nerfacc.estimators.occ_grid import OccGridEstimator
@@ -45,11 +46,6 @@ parser.add_argument(
     choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
     help="which scene to use",
 )
-parser.add_argument(
-    "--test_chunk_size",
-    type=int,
-    default=8192,
-)
 
 args = parser.parse_args()
 device = "cuda:0"
@@ -233,7 +229,9 @@ for step in range(max_steps + 1):
                 pixels = data["pixels"]
 
                 # rendering
-                rgb, acc, depth, _ = render_image_with_occgrid(
+                rgb, acc, depth, _ = render_image_with_occgrid_test(
+                    1024,
+                    # scene
                     radiance_field,
                     estimator,
                     rays,
@@ -243,8 +241,6 @@ for step in range(max_steps + 1):
                     render_bkgd=render_bkgd,
                     cone_angle=cone_angle,
                     alpha_thre=alpha_thre,
-                    # test options
-                    test_chunk_size=args.test_chunk_size,
                 )
                 mse = F.mse_loss(rgb, pixels)
                 psnr = -10.0 * torch.log(mse) / np.log(10.0)
@@ -17,7 +17,12 @@ from torch.utils.data._utils.collate import collate, default_collate_fn_map
 
 from nerfacc.estimators.occ_grid import OccGridEstimator
 from nerfacc.estimators.prop_net import PropNetEstimator
-from nerfacc.volrend import rendering
+from nerfacc.grid import ray_aabb_intersect, traverse_grids
+from nerfacc.volrend import (
+    accumulate_along_rays_,
+    render_weight_from_density,
+    rendering,
+)
 
 NERF_SYNTHETIC_SCENES = [
     "chair",
@@ -242,3 +247,179 @@ def render_image_with_propnet(
         depths.view((*rays_shape[:-1], -1)),
         extras,
     )

@torch.no_grad()
def render_image_with_occgrid_test(
max_samples: int,
# scene
radiance_field: torch.nn.Module,
estimator: OccGridEstimator,
rays: Rays,
# rendering options
near_plane: float = 0.0,
far_plane: float = 1e10,
render_step_size: float = 1e-3,
render_bkgd: Optional[torch.Tensor] = None,
cone_angle: float = 0.0,
alpha_thre: float = 0.0,
early_stop_eps: float = 1e-4,
# only useful for dnerf
timestamps: Optional[torch.Tensor] = None,
):
"""Render the pixels of an image."""
rays_shape = rays.origins.shape
if len(rays_shape) == 3:
height, width, _ = rays_shape
num_rays = height * width
rays = namedtuple_map(
lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays
)
else:
num_rays, _ = rays_shape
def rgb_sigma_fn(t_starts, t_ends, ray_indices):
t_origins = rays.origins[ray_indices]
t_dirs = rays.viewdirs[ray_indices]
positions = (
t_origins + t_dirs * (t_starts[:, None] + t_ends[:, None]) / 2.0
)
if timestamps is not None:
# dnerf
t = (
timestamps[ray_indices]
if radiance_field.training
else timestamps.expand_as(positions[:, :1])
)
rgbs, sigmas = radiance_field(positions, t, t_dirs)
else:
rgbs, sigmas = radiance_field(positions, t_dirs)
return rgbs, sigmas.squeeze(-1)
device = rays.origins.device
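    # Per-ray accumulation buffers, plus a mask of rays that are still being marched.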
opacity = torch.zeros(num_rays, 1, device=device)
depth = torch.zeros(num_rays, 1, device=device)
rgb = torch.zeros(num_rays, 3, device=device)
ray_mask = torch.ones(num_rays, device=device).bool()
# 1 for synthetic scenes, 4 for real scenes
min_samples = 1 if cone_angle == 0 else 4
iter_samples = total_samples = 0
rays_o = rays.origins
rays_d = rays.viewdirs
near_planes = torch.full_like(rays_o[..., 0], fill_value=near_plane)
far_planes = torch.full_like(rays_o[..., 0], fill_value=far_plane)
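
    # Entry/exit distances of every ray against each occupancy-grid AABB; with more
    # than one grid, the boundaries are pre-sorted so traversal can walk the grids
    # in near-to-far order along each ray.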
t_mins, t_maxs, hits = ray_aabb_intersect(rays_o, rays_d, estimator.aabbs)
n_grids = estimator.binaries.size(0)
if n_grids > 1:
t_sorted, t_indices = torch.sort(torch.cat([t_mins, t_maxs], -1), -1)
else:
t_sorted = torch.cat([t_mins, t_maxs], -1)
t_indices = torch.arange(
0, n_grids * 2, device=t_mins.device, dtype=torch.int64
).expand(num_rays, n_grids * 2)
opc_thre = 1 - early_stop_eps
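
    # March the alive rays a chunk of samples at a time: query the radiance field,
    # composite, then drop rays that have saturated (early stop) or left the grids.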
while iter_samples < max_samples:
n_alive = ray_mask.sum().item()
if n_alive == 0:
break
# the number of samples to add on each ray
n_samples = max(min(num_rays // n_alive, 64), min_samples)
iter_samples += n_samples
# ray marching
(intervals, samples, termination_planes) = traverse_grids(
# rays
rays_o, # [n_rays, 3]
rays_d, # [n_rays, 3]
# grids
estimator.binaries, # [m, resx, resy, resz]
estimator.aabbs, # [m, 6]
# options
near_planes, # [n_rays]
far_planes, # [n_rays]
render_step_size,
cone_angle,
n_samples,
True,
ray_mask,
# pre-compute intersections
t_sorted, # [n_rays, m*2]
t_indices, # [n_rays, m*2]
hits, # [n_rays, m]
)
t_starts = intervals.vals[intervals.is_left]
t_ends = intervals.vals[intervals.is_right]
ray_indices = samples.ray_indices[samples.is_valid]
packed_info = samples.packed_info
# get rgb and sigma from radiance field
rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices)
# volume rendering using native cuda scan
weights, _, alphas = render_weight_from_density(
t_starts,
t_ends,
sigmas,
ray_indices=ray_indices,
n_rays=num_rays,
prefix_trans=1 - opacity[ray_indices].squeeze(-1),
)
if alpha_thre > 0:
vis_mask = alphas >= alpha_thre
ray_indices, rgbs, weights, t_starts, t_ends = (
ray_indices[vis_mask],
rgbs[vis_mask],
weights[vis_mask],
t_starts[vis_mask],
t_ends[vis_mask],
)
accumulate_along_rays_(
weights,
values=rgbs,
ray_indices=ray_indices,
outputs=rgb,
)
accumulate_along_rays_(
weights,
values=None,
ray_indices=ray_indices,
outputs=opacity,
)
accumulate_along_rays_(
weights,
values=(t_starts + t_ends)[..., None] / 2.0,
ray_indices=ray_indices,
outputs=depth,
)
# update near_planes using termination planes
near_planes = termination_planes
# update rays status
ray_mask = torch.logical_and(
# early stopping
opacity.view(-1) <= opc_thre,
# remove rays that have reached the far plane
packed_info[:, 1] == n_samples,
)
total_samples += ray_indices.shape[0]
rgb = rgb + render_bkgd * (1.0 - opacity)
depth = depth / opacity.clamp_min(torch.finfo(rgbs.dtype).eps)
return (
rgb.view((*rays_shape[:-1], -1)),
opacity.view((*rays_shape[:-1], -1)),
depth.view((*rays_shape[:-1], -1)),
total_samples,
)
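
For reference, the evaluation loop of the example training script calls the new helper as shown in the first hunk of this diff. The snippet below is a minimal, hypothetical usage sketch based on that call site and the signature above; `radiance_field`, `estimator`, `rays`, `pixels`, and the rendering options are assumed to be set up as in the example script (some of the keyword arguments are elided in the hunk).

# Hypothetical evaluation-time usage mirroring the call-site change above; the
# helper is already decorated with @torch.no_grad(), so no extra guard is needed.
rgb, acc, depth, n_samples = render_image_with_occgrid_test(
    1024,  # max_samples: cap on the total samples marched per ray
    # scene
    radiance_field,
    estimator,
    rays,
    # rendering options
    near_plane=near_plane,
    render_step_size=render_step_size,
    render_bkgd=render_bkgd,
    cone_angle=cone_angle,
    alpha_thre=alpha_thre,
)
psnr = -10.0 * torch.log(F.mse_loss(rgb, pixels)) / np.log(10.0)

Because the test-mode renderer marches all rays together in small chunks and stops each ray as soon as its opacity saturates, the example no longer needs a test-time chunking argument, which is why this commit also removes --test_chunk_size.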