""" Copyright (c) 2022 Ruilong Li, UC Berkeley. """ import random from typing import Optional import numpy as np import torch from datasets.utils import Rays, namedtuple_map from nerfacc import OccupancyGrid, ray_marching, rendering def set_random_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) def render_image( # scene radiance_field: torch.nn.Module, occupancy_grid: OccupancyGrid, rays: Rays, scene_aabb: torch.Tensor, # rendering options near_plane: Optional[float] = None, far_plane: Optional[float] = None, render_step_size: float = 1e-3, render_bkgd: Optional[torch.Tensor] = None, cone_angle: float = 0.0, alpha_thre: float = 0.0, # test options test_chunk_size: int = 8192, # only useful for dnerf timestamps: Optional[torch.Tensor] = None, ): """Render the pixels of an image.""" rays_shape = rays.origins.shape if len(rays_shape) == 3: height, width, _ = rays_shape num_rays = height * width rays = namedtuple_map( lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays ) else: num_rays, _ = rays_shape def sigma_fn(t_starts, t_ends, ray_indices): t_origins = chunk_rays.origins[ray_indices] t_dirs = chunk_rays.viewdirs[ray_indices] positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0 if timestamps is not None: # dnerf t = ( timestamps[ray_indices] if radiance_field.training else timestamps.expand_as(positions[:, :1]) ) return radiance_field.query_density(positions, t) return radiance_field.query_density(positions) def rgb_sigma_fn(t_starts, t_ends, ray_indices): t_origins = chunk_rays.origins[ray_indices] t_dirs = chunk_rays.viewdirs[ray_indices] positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0 if timestamps is not None: # dnerf t = ( timestamps[ray_indices] if radiance_field.training else timestamps.expand_as(positions[:, :1]) ) return radiance_field(positions, t, t_dirs) return radiance_field(positions, t_dirs) results = [] chunk = ( torch.iinfo(torch.int32).max if radiance_field.training else test_chunk_size ) for i in range(0, num_rays, chunk): chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays) ray_indices, t_starts, t_ends = ray_marching( chunk_rays.origins, chunk_rays.viewdirs, scene_aabb=scene_aabb, grid=occupancy_grid, sigma_fn=sigma_fn, near_plane=near_plane, far_plane=far_plane, render_step_size=render_step_size, stratified=radiance_field.training, cone_angle=cone_angle, alpha_thre=alpha_thre, ) rgb, opacity, depth = rendering( t_starts, t_ends, ray_indices, n_rays=chunk_rays.origins.shape[0], rgb_sigma_fn=rgb_sigma_fn, render_bkgd=render_bkgd, ) chunk_results = [rgb, opacity, depth, len(t_starts)] results.append(chunk_results) colors, opacities, depths, n_rendering_samples = [ torch.cat(r, dim=0) if isinstance(r[0], torch.Tensor) else r for r in zip(*results) ] return ( colors.view((*rays_shape[:-1], -1)), opacities.view((*rays_shape[:-1], -1)), depths.view((*rays_shape[:-1], -1)), sum(n_rendering_samples), )