""" Copyright (c) 2022 Ruilong Li, UC Berkeley. """ import argparse import math import pathlib import time import imageio import numpy as np import torch import torch.nn.functional as F import tqdm from datasets.dnerf_synthetic import SubjectLoader from radiance_fields.mlp import DNeRFRadianceField from utils import render_image, set_random_seed from nerfacc import OccupancyGrid device = "cuda:0" set_random_seed(42) parser = argparse.ArgumentParser() parser.add_argument( "--data_root", type=str, default=str(pathlib.Path.cwd() / "data/dnerf"), help="the root dir of the dataset", ) parser.add_argument( "--train_split", type=str, default="train", choices=["train"], help="which train split to use", ) parser.add_argument( "--scene", type=str, default="lego", choices=[ # dnerf "bouncingballs", "hellwarrior", "hook", "jumpingjacks", "lego", "mutant", "standup", "trex", ], help="which scene to use", ) parser.add_argument( "--aabb", type=lambda s: [float(item) for item in s.split(",")], default="-1.5,-1.5,-1.5,1.5,1.5,1.5", help="delimited list input", ) parser.add_argument( "--test_chunk_size", type=int, default=8192, ) parser.add_argument("--cone_angle", type=float, default=0.0) args = parser.parse_args() render_n_samples = 1024 # setup the scene bounding box. scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device) near_plane = None far_plane = None render_step_size = ( (scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples ).item() # setup the radiance field we want to train. max_steps = 30000 grad_scaler = torch.cuda.amp.GradScaler(1) radiance_field = DNeRFRadianceField().to(device) optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4) scheduler = torch.optim.lr_scheduler.MultiStepLR( optimizer, milestones=[ max_steps // 2, max_steps * 3 // 4, max_steps * 5 // 6, max_steps * 9 // 10, ], gamma=0.33, ) # setup the dataset target_sample_batch_size = 1 << 16 grid_resolution = 128 train_dataset = SubjectLoader( subject_id=args.scene, root_fp=args.data_root, split=args.train_split, num_rays=target_sample_batch_size // render_n_samples, device=device, ) test_dataset = SubjectLoader( subject_id=args.scene, root_fp=args.data_root, split="test", num_rays=None, device=device, ) occupancy_grid = OccupancyGrid( roi_aabb=args.aabb, resolution=grid_resolution ).to(device) # training step = 0 tic = time.time() for epoch in range(10000000): for i in range(len(train_dataset)): radiance_field.train() data = train_dataset[i] render_bkgd = data["color_bkgd"] rays = data["rays"] pixels = data["pixels"] timestamps = data["timestamps"] # update occupancy grid occupancy_grid.every_n_step( step=step, occ_eval_fn=lambda x: radiance_field.query_opacity( x, timestamps, render_step_size ), ) # render rgb, acc, depth, n_rendering_samples = render_image( radiance_field, occupancy_grid, rays, scene_aabb, # rendering options near_plane=near_plane, far_plane=far_plane, render_step_size=render_step_size, render_bkgd=render_bkgd, cone_angle=args.cone_angle, alpha_thre=0.01 if step > 1000 else 0.00, # dnerf options timestamps=timestamps, ) if n_rendering_samples == 0: continue # dynamic batch size for rays to keep sample batch size constant. num_rays = len(pixels) num_rays = int( num_rays * (target_sample_batch_size / float(n_rendering_samples)) ) train_dataset.update_num_rays(num_rays) alive_ray_mask = acc.squeeze(-1) > 0 # compute loss loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask]) optimizer.zero_grad() # do not unscale it because we are using Adam. 
        grad_scaler.scale(loss).backward()
        optimizer.step()
        scheduler.step()

        if step % 5000 == 0:
            elapsed_time = time.time() - tic
            loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
            print(
                f"elapsed_time={elapsed_time:.2f}s | step={step} | "
                f"loss={loss:.5f} | "
                f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
                f"n_rendering_samples={n_rendering_samples:d} | "
                f"num_rays={len(pixels):d} |"
            )

        if step > 0 and step % max_steps == 0:
            # evaluation
            radiance_field.eval()

            psnrs = []
            with torch.no_grad():
                for i in tqdm.tqdm(range(len(test_dataset))):
                    data = test_dataset[i]
                    render_bkgd = data["color_bkgd"]
                    rays = data["rays"]
                    pixels = data["pixels"]
                    timestamps = data["timestamps"]

                    # rendering
                    rgb, acc, depth, _ = render_image(
                        radiance_field,
                        occupancy_grid,
                        rays,
                        scene_aabb,
                        # rendering options
                        near_plane=None,
                        far_plane=None,
                        render_step_size=render_step_size,
                        render_bkgd=render_bkgd,
                        cone_angle=args.cone_angle,
                        alpha_thre=0.01,
                        # test options
                        test_chunk_size=args.test_chunk_size,
                        # dnerf options
                        timestamps=timestamps,
                    )
                    mse = F.mse_loss(rgb, pixels)
                    psnr = -10.0 * torch.log(mse) / np.log(10.0)
                    psnrs.append(psnr.item())
                    # imageio.imwrite(
                    #     "acc_binary_test.png",
                    #     ((acc > 0).float().cpu().numpy() * 255).astype(
                    #         np.uint8
                    #     ),
                    # )
                    # imageio.imwrite(
                    #     "rgb_test.png",
                    #     (rgb.cpu().numpy() * 255).astype(np.uint8),
                    # )
                    # break
            psnr_avg = sum(psnrs) / len(psnrs)
            print(f"evaluation: psnr_avg={psnr_avg}")
            train_dataset.training = True

        if step == max_steps:
            print("training stops")
            exit()

        step += 1
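# A minimal invocation sketch, assuming this script is saved as
# train_mlp_dnerf.py and the D-NeRF synthetic scenes have been downloaded to
# ./data/dnerf (both the filename and the data location are assumptions, not
# fixed by this file; the flags are the ones defined by the parser above):
#
#   python train_mlp_dnerf.py --scene lego \
#       --data_root ./data/dnerf --test_chunk_size 8192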