""" Copyright (c) 2022 Ruilong Li, UC Berkeley. """ import argparse import math import os import random import time from typing import Optional import imageio import numpy as np import torch import torch.nn.functional as F import tqdm from datasets.utils import Rays, namedtuple_map from radiance_fields.ngp import NGPradianceField from utils import set_random_seed from nerfacc import ContractionType, pack_info, ray_marching, rendering from nerfacc.cuda import ray_pdf_query def set_random_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) # @profile def render_image( # scene radiance_field: torch.nn.Module, proposal_nets: torch.nn.Module, rays: Rays, scene_aabb: torch.Tensor, # rendering options near_plane: Optional[float] = None, far_plane: Optional[float] = None, render_step_size: float = 1e-3, render_bkgd: Optional[torch.Tensor] = None, cone_angle: float = 0.0, alpha_thre: float = 0.0, proposal_nets_require_grads: bool = True, # test options test_chunk_size: int = 8192, ): """Render the pixels of an image.""" rays_shape = rays.origins.shape if len(rays_shape) == 3: height, width, _ = rays_shape num_rays = height * width rays = namedtuple_map( lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays ) else: num_rays, _ = rays_shape def sigma_fn(t_starts, t_ends, ray_indices, net=None): ray_indices = ray_indices.long() t_origins = chunk_rays.origins[ray_indices] t_dirs = chunk_rays.viewdirs[ray_indices] positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0 if net is not None: return net.query_density(positions) else: return radiance_field.query_density(positions) def rgb_sigma_fn(t_starts, t_ends, ray_indices): ray_indices = ray_indices.long() t_origins = chunk_rays.origins[ray_indices] t_dirs = chunk_rays.viewdirs[ray_indices] positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0 return radiance_field(positions, t_dirs) results = [] chunk = ( torch.iinfo(torch.int32).max if radiance_field.training else test_chunk_size ) for i in range(0, num_rays, chunk): chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays) ray_indices, t_starts, t_ends, proposal_sample_list = ray_marching( chunk_rays.origins, chunk_rays.viewdirs, scene_aabb=scene_aabb, grid=None, # proposal density fns: {t_starts, t_ends, ray_indices} -> density proposal_sigma_fns=[ lambda t_starts, t_ends, ray_indices: sigma_fn( t_starts, t_ends, ray_indices, proposal_net ) for proposal_net in proposal_nets ], proposal_n_samples=[32], proposal_require_grads=proposal_nets_require_grads, sigma_fn=sigma_fn, near_plane=near_plane, far_plane=far_plane, render_step_size=render_step_size, stratified=radiance_field.training, cone_angle=cone_angle, alpha_thre=alpha_thre, ) rgb, opacity, depth, weights = rendering( t_starts, t_ends, ray_indices=ray_indices, n_rays=len(chunk_rays.origins), rgb_sigma_fn=rgb_sigma_fn, render_bkgd=render_bkgd, ) if radiance_field.training: packed_info = pack_info(ray_indices, n_rays=len(chunk_rays.origins)) proposal_sample_list.append( (packed_info, t_starts, t_ends, weights) ) chunk_results = [rgb, opacity, depth, len(t_starts)] results.append(chunk_results) colors, opacities, depths, n_rendering_samples = [ torch.cat(r, dim=0) if isinstance(r[0], torch.Tensor) else r for r in zip(*results) ] return ( colors.view((*rays_shape[:-1], -1)), opacities.view((*rays_shape[:-1], -1)), depths.view((*rays_shape[:-1], -1)), sum(n_rendering_samples), proposal_sample_list if radiance_field.training else None, ) if __name__ == "__main__": device = "cuda:0" set_random_seed(42) parser = 


if __name__ == "__main__":

    device = "cuda:0"
    set_random_seed(42)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train_split",
        type=str,
        default="trainval",
        choices=["train", "trainval"],
        help="which train split to use",
    )
    parser.add_argument(
        "--scene",
        type=str,
        default="lego",
        choices=[
            # nerf synthetic
            "chair",
            "drums",
            "ficus",
            "hotdog",
            "lego",
            "materials",
            "mic",
            "ship",
            # mipnerf360 unbounded
            "garden",
            "bicycle",
            "bonsai",
            "counter",
            "kitchen",
            "room",
            "stump",
        ],
        help="which scene to use",
    )
    parser.add_argument(
        "--aabb",
        type=lambda s: [float(item) for item in s.split(",")],
        default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
        help="delimited list input",
    )
    parser.add_argument(
        "--test_chunk_size",
        type=int,
        default=1024,
    )
    parser.add_argument(
        "--unbounded",
        action="store_true",
        help="whether to use unbounded rendering",
    )
    parser.add_argument(
        "--auto_aabb",
        action="store_true",
        help="whether to automatically compute the aabb",
    )
    parser.add_argument("--cone_angle", type=float, default=0.0)
    args = parser.parse_args()

    render_n_samples = 256

    # setup the dataset
    train_dataset_kwargs = {}
    test_dataset_kwargs = {}
    if args.unbounded:
        from datasets.nerf_360_v2 import SubjectLoader

        data_root_fp = "/home/ruilongli/data/360_v2/"
        target_sample_batch_size = 1 << 20
        train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
        test_dataset_kwargs = {"factor": 4}
    else:
        from datasets.nerf_synthetic import SubjectLoader

        data_root_fp = "/home/ruilongli/data/nerf_synthetic/"
        target_sample_batch_size = 1 << 18

    train_dataset = SubjectLoader(
        subject_id=args.scene,
        root_fp=data_root_fp,
        split=args.train_split,
        num_rays=target_sample_batch_size // 32,
        **train_dataset_kwargs,
    )
    train_dataset.images = train_dataset.images.to(device)
    train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
    train_dataset.K = train_dataset.K.to(device)

    test_dataset = SubjectLoader(
        subject_id=args.scene,
        root_fp=data_root_fp,
        split="test",
        num_rays=None,
        **test_dataset_kwargs,
    )
    test_dataset.images = test_dataset.images.to(device)
    test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
    test_dataset.K = test_dataset.K.to(device)

    if args.auto_aabb:
        camera_locs = torch.cat(
            [train_dataset.camtoworlds, test_dataset.camtoworlds]
        )[:, :3, -1]
        args.aabb = torch.cat(
            [camera_locs.min(dim=0).values, camera_locs.max(dim=0).values]
        ).tolist()
        print("Using auto aabb", args.aabb)

    # setup the scene bounding box.
    if args.unbounded:
        print("Using unbounded rendering")
        contraction_type = ContractionType.UN_BOUNDED_SPHERE
        # contraction_type = ContractionType.UN_BOUNDED_TANH
        scene_aabb = None
        near_plane = 0.2
        far_plane = 1e4
        render_step_size = 1e-2
        alpha_thre = 1e-2
    else:
        contraction_type = ContractionType.AABB
        scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
        near_plane = None
        far_plane = None
        render_step_size = (
            (scene_aabb[3:] - scene_aabb[:3]).max()
            * math.sqrt(3)
            / render_n_samples
        ).item()
        alpha_thre = 1e-2

    proposal_nets = torch.nn.ModuleList(
        [
            NGPradianceField(
                aabb=args.aabb,
                use_viewdirs=False,
                hidden_dim=0,
                geo_feat_dim=0,
            ),
            # NGPradianceField(
            #     aabb=args.aabb,
            #     use_viewdirs=False,
            #     hidden_dim=16,
            #     max_res=64,
            #     geo_feat_dim=0,
            #     n_levels=4,
            #     log2_hashmap_size=19,
            # ),
            # NGPradianceField(
            #     aabb=args.aabb,
            #     use_viewdirs=False,
            #     hidden_dim=16,
            #     max_res=256,
            #     geo_feat_dim=0,
            #     n_levels=5,
            #     log2_hashmap_size=17,
            # ),
        ]
    ).to(device)
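
    # The single active proposal net above is a density-only NGP field
    # (hidden_dim=0, geo_feat_dim=0, no view directions); the commented-out
    # entries sketch how additional, finer proposal levels could be stacked.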

    # setup the radiance field we want to train.
    max_steps = 20000
    grad_scaler = torch.cuda.amp.GradScaler(2**10)
    radiance_field = NGPradianceField(
        aabb=args.aabb,
        unbounded=args.unbounded,
    ).to(device)
    optimizer = torch.optim.Adam(
        list(radiance_field.parameters()) + list(proposal_nets.parameters()),
        lr=1e-2,
        eps=1e-15,
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[max_steps // 2, max_steps * 3 // 4, max_steps * 9 // 10],
        gamma=0.33,
    )

    # training
    step = 0
    tic = time.time()
    for epoch in range(10000000):
        for i in range(len(train_dataset)):
            radiance_field.train()
            proposal_nets.train()

            # @profile
            def _train():
                data = train_dataset[i]
                render_bkgd = data["color_bkgd"]
                rays = data["rays"]
                pixels = data["pixels"]

                # render
                (
                    rgb,
                    acc,
                    depth,
                    n_rendering_samples,
                    proposal_sample_list,
                ) = render_image(
                    radiance_field,
                    proposal_nets,
                    rays,
                    scene_aabb,
                    # rendering options
                    near_plane=near_plane,
                    far_plane=far_plane,
                    render_step_size=render_step_size,
                    render_bkgd=render_bkgd,
                    cone_angle=args.cone_angle,
                    alpha_thre=min(alpha_thre, alpha_thre * step / 1000),
                    proposal_nets_require_grads=(step < 100 or step % 16 == 0),
                )
                # if n_rendering_samples == 0:
                #     continue

                # dynamic batch size for rays to keep sample batch size constant.
                num_rays = len(pixels)
                num_rays = int(
                    num_rays
                    * (target_sample_batch_size / float(n_rendering_samples))
                )
                train_dataset.update_num_rays(num_rays)
                alive_ray_mask = acc.squeeze(-1) > 0

                # compute loss
                loss = F.smooth_l1_loss(
                    rgb[alive_ray_mask], pixels[alive_ray_mask]
                )

                (
                    packed_info,
                    t_starts,
                    t_ends,
                    weights,
                ) = proposal_sample_list[-1]
                loss_interval = 0.0
                for (
                    proposal_packed_info,
                    proposal_t_starts,
                    proposal_t_ends,
                    proposal_weights,
                ) in proposal_sample_list[:-1]:
                    proposal_weights_gt = ray_pdf_query(
                        packed_info,
                        t_starts,
                        t_ends,
                        weights.detach(),
                        proposal_packed_info,
                        proposal_t_starts,
                        proposal_t_ends,
                    ).detach()
                    loss_interval = (
                        torch.clamp(
                            proposal_weights_gt - proposal_weights, min=0
                        )
                    ) ** 2 / (proposal_weights + torch.finfo(torch.float32).eps)
                    loss_interval = loss_interval.mean()
                    loss += loss_interval * 1.0
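
                # The loop above is a proposal ("interval") loss: each proposal
                # level is penalized only where its weights fall below the
                # weights queried from the final NeRF samples via ray_pdf_query,
                # i.e. clamp(w_gt - w_prop, 0)^2 / (w_prop + eps), in the spirit
                # of the mip-NeRF 360 proposal loss.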

                optimizer.zero_grad()
                # do not unscale it because we are using Adam.
                grad_scaler.scale(loss).backward()
                optimizer.step()
                scheduler.step()

                if step % 100 == 0:
                    elapsed_time = time.time() - tic
                    loss = F.mse_loss(
                        rgb[alive_ray_mask], pixels[alive_ray_mask]
                    )
                    print(
                        f"elapsed_time={elapsed_time:.2f}s | step={step} | "
                        f"loss={loss:.5f} | loss_interval={loss_interval:.5f} | "
                        f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
                        f"n_rendering_samples={n_rendering_samples:d} | "
                        f"num_rays={len(pixels):d} |"
                    )

            _train()

            if step > 0 and step % 1000 == 0:
                # evaluation
                radiance_field.eval()
                proposal_nets.eval()

                psnrs = []
                with torch.no_grad():
                    for i in tqdm.tqdm(range(len(test_dataset))):
                        data = test_dataset[i]
                        render_bkgd = data["color_bkgd"]
                        rays = data["rays"]
                        pixels = data["pixels"]

                        # rendering
                        rgb, acc, depth, _, _ = render_image(
                            radiance_field,
                            proposal_nets,
                            rays,
                            scene_aabb,
                            # rendering options
                            near_plane=near_plane,
                            far_plane=far_plane,
                            render_step_size=render_step_size,
                            render_bkgd=render_bkgd,
                            cone_angle=args.cone_angle,
                            alpha_thre=alpha_thre,
                            proposal_nets_require_grads=False,
                            # test options
                            test_chunk_size=args.test_chunk_size,
                        )
                        mse = F.mse_loss(rgb, pixels)
                        psnr = -10.0 * torch.log(mse) / np.log(10.0)
                        psnrs.append(psnr.item())
                        imageio.imwrite(
                            "acc_binary_test.png",
                            ((acc > 0).float().cpu().numpy() * 255).astype(
                                np.uint8
                            ),
                        )
                        imageio.imwrite(
                            "rgb_test.png",
                            (rgb.cpu().numpy() * 255).astype(np.uint8),
                        )
                        break
                psnr_avg = sum(psnrs) / len(psnrs)
                print(f"evaluation: psnr_avg={psnr_avg}")
                train_dataset.training = True

            if step == max_steps:
                print("training stops")
                exit()

            step += 1
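
# Example invocations (the script filename is a placeholder; the flags match
# the argparse options defined above):
#   python train_ngp_proposal.py --scene lego --train_split trainval
#   python train_ngp_proposal.py --scene garden --unbounded --auto_aabb --cone_angle 0.004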