import argparse
import math
import os
import time

import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from datasets.dnerf_synthetic import SubjectLoader
from radiance_fields.mlp import DNeRFRadianceField
from utils import render_image, set_random_seed

from nerfacc import ContractionType, OccupancyGrid

if __name__ == "__main__":

    device = "cuda:0"
    set_random_seed(42)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train_split",
        type=str,
        default="train",
        choices=["train"],
        help="which train split to use",
    )
    parser.add_argument(
        "--scene",
        type=str,
        default="lego",
        choices=[
            # dnerf
            "bouncingballs",
            "hellwarrior",
            "hook",
            "jumpingjacks",
            "lego",
            "mutant",
            "standup",
            "trex",
        ],
        help="which scene to use",
    )
    parser.add_argument(
        "--aabb",
        type=lambda s: [float(item) for item in s.split(",")],
        default="-1.5,-1.5,-1.5,1.5,1.5,1.5",
        help="delimited list input",
    )
    parser.add_argument(
        "--test_chunk_size",
        type=int,
        default=8192,
    )
    parser.add_argument("--cone_angle", type=float, default=0.0)
    args = parser.parse_args()

    render_n_samples = 1024

    # setup the scene bounding box.
    contraction_type = ContractionType.AABB
    scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device)
    near_plane = None
    far_plane = None
    # step size chosen so that a ray crossing the aabb diagonal takes
    # about render_n_samples steps.
    render_step_size = (
        (scene_aabb[3:] - scene_aabb[:3]).max()
        * math.sqrt(3)
        / render_n_samples
    ).item()

    # setup the radiance field we want to train.
    max_steps = 50000
    # init_scale=1 and scaler.update() is never called, so the scale stays
    # at 1 and the loss is effectively unscaled.
    grad_scaler = torch.cuda.amp.GradScaler(1)
    radiance_field = DNeRFRadianceField().to(device)
    optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[
            max_steps // 2,
            max_steps * 3 // 4,
            max_steps * 5 // 6,
            max_steps * 9 // 10,
        ],
        gamma=0.33,
    )

    # setup the dataset
    data_root_fp = "/home/ruilongli/data/dnerf/"
    target_sample_batch_size = 1 << 16
    grid_resolution = 128

    train_dataset = SubjectLoader(
        subject_id=args.scene,
        root_fp=data_root_fp,
        split=args.train_split,
        num_rays=target_sample_batch_size // render_n_samples,
    )
    train_dataset.images = train_dataset.images.to(device)
    train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
    train_dataset.K = train_dataset.K.to(device)
    train_dataset.timestamps = train_dataset.timestamps.to(device)

    test_dataset = SubjectLoader(
        subject_id=args.scene,
        root_fp=data_root_fp,
        split="test",
        num_rays=None,
    )
    test_dataset.images = test_dataset.images.to(device)
    test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
    test_dataset.K = test_dataset.K.to(device)
    test_dataset.timestamps = test_dataset.timestamps.to(device)

    occupancy_grid = OccupancyGrid(
        roi_aabb=args.aabb,
        resolution=grid_resolution,
        contraction_type=contraction_type,
    ).to(device)

    # training
    step = 0
    tic = time.time()
    for epoch in range(10000000):
        for i in range(len(train_dataset)):
            radiance_field.train()
            data = train_dataset[i]

            render_bkgd = data["color_bkgd"]
            rays = data["rays"]
            pixels = data["pixels"]
            timestamps = data["timestamps"]

            # update occupancy grid
            occupancy_grid.every_n_step(
                step=step,
                occ_eval_fn=lambda x: radiance_field.query_opacity(
                    x, timestamps, render_step_size
                ),
            )

            # render
            rgb, acc, depth, n_rendering_samples = render_image(
                radiance_field,
                occupancy_grid,
                rays,
                scene_aabb,
                # rendering options
                near_plane=near_plane,
                far_plane=far_plane,
                render_step_size=render_step_size,
                render_bkgd=render_bkgd,
                cone_angle=args.cone_angle,
                # dnerf options
                timestamps=timestamps,
            )

            # skip this step if the occupancy grid pruned every sample,
            # which would otherwise divide by zero below.
            if n_rendering_samples == 0:
                continue
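            # A worked example of the dynamic batch sizing below, with
            # assumed numbers for illustration only: target_sample_batch_size
            # is 1 << 16 = 65536. If this step cast num_rays = 4096 rays and
            # the renderer reported n_rendering_samples = 131072, the next
            # batch becomes int(4096 * (65536 / 131072.0)) = 2048 rays, i.e.
            # the ray count halves so that the number of samples per step
            # drifts back toward 65536.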
            # dynamic batch size for rays to keep sample batch size constant.
            num_rays = len(pixels)
            num_rays = int(
                num_rays
                * (target_sample_batch_size / float(n_rendering_samples))
            )
            train_dataset.update_num_rays(num_rays)
            alive_ray_mask = acc.squeeze(-1) > 0

            # compute loss only on rays that hit occupied space.
            loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])

            optimizer.zero_grad()
            # do not unscale it because we are using Adam.
            grad_scaler.scale(loss).backward()
            optimizer.step()
            scheduler.step()

            if step % 5000 == 0:
                elapsed_time = time.time() - tic
                mse = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
                print(
                    f"elapsed_time={elapsed_time:.2f}s | {step=} | "
                    f"loss={mse:.5f} | "
                    f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
                    f"n_rendering_samples={n_rendering_samples:d} | "
                    f"num_rays={len(pixels):d} |"
                )

            if step > 0 and step % max_steps == 0:
                # evaluation
                radiance_field.eval()

                psnrs = []
                with torch.no_grad():
                    for i in tqdm.tqdm(range(len(test_dataset))):
                        data = test_dataset[i]
                        render_bkgd = data["color_bkgd"]
                        rays = data["rays"]
                        pixels = data["pixels"]
                        timestamps = data["timestamps"]

                        # rendering
                        rgb, acc, depth, _ = render_image(
                            radiance_field,
                            occupancy_grid,
                            rays,
                            scene_aabb,
                            # rendering options
                            near_plane=None,
                            far_plane=None,
                            render_step_size=render_step_size,
                            render_bkgd=render_bkgd,
                            cone_angle=args.cone_angle,
                            # test options
                            test_chunk_size=args.test_chunk_size,
                            # dnerf options
                            timestamps=timestamps,
                        )
                        mse = F.mse_loss(rgb, pixels)
                        psnr = -10.0 * torch.log(mse) / np.log(10.0)
                        psnrs.append(psnr.item())
                        # imageio.imwrite(
                        #     "acc_binary_test.png",
                        #     ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
                        # )
                        # imageio.imwrite(
                        #     "rgb_test.png",
                        #     (rgb.cpu().numpy() * 255).astype(np.uint8),
                        # )
                        # break
                psnr_avg = sum(psnrs) / len(psnrs)
                print(f"evaluation: {psnr_avg=}")
                # restore random-ray sampling on the train split.
                train_dataset.training = True

            if step == max_steps:
                print("training stops")
                exit()

            step += 1
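# Note on the PSNR computed above: psnr = -10 * log10(mse), and the torch
# expression -10.0 * torch.log(mse) / np.log(10.0) is the same quantity via
# the change of base log10(x) = ln(x) / ln(10). Illustrative values (assumed,
# not measured): mse = 1e-3 -> 30 dB, mse = 1e-4 -> 40 dB.
#
# Example invocation (the script filename is hypothetical; adjust
# data_root_fp above to point at your D-NeRF data):
#   python train_mlp_dnerf.py --scene lego --cone_angle 0.0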