"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""

# Training script: optimises an NGPRadianceField together with one or more
# NGPDensityField proposal networks on a single scene, logs MSE/PSNR every
# 5000 steps, and reports the average PSNR / LPIPS over the test split once
# the final optimisation step has run.

import argparse
import itertools
import pathlib
import time

import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from lpips import LPIPS
from radiance_fields.ngp import NGPDensityField, NGPRadianceField
from utils import (
    MIPNERF360_UNBOUNDED_SCENES,
    NERF_SYNTHETIC_SCENES,
    render_image_proposal,
    set_random_seed,
)

from nerfacc.proposal import (
    compute_prop_loss,
    get_proposal_annealing_fn,
    get_proposal_requires_grad_fn,
)

# ---------------------------------------------------------------------------
# Command line
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "--data_root",
    type=str,
    # Alternative default for the 360 scenes:
    # default=str(pathlib.Path.cwd() / "data/360_v2"),
    default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
    help="the root dir of the dataset",
)
parser.add_argument(
    "--train_split",
    type=str,
    default="train",
    choices=["train", "trainval"],
    help="which train split to use",
)
parser.add_argument(
    "--scene",
    type=str,
    default="lego",
    choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
    help="which scene to use",
)
parser.add_argument("--test_chunk_size", type=int, default=8192)
args = parser.parse_args()

device = "cuda:0"
set_random_seed(42)

# ---------------------------------------------------------------------------
# Scene-dependent configuration
# ---------------------------------------------------------------------------
# Training budget; the same for both scene families.
max_steps = 20000
init_batch_size = 4096

# The scene list a name belongs to decides loader, bounds and sampling setup.
unbounded = args.scene in MIPNERF360_UNBOUNDED_SCENES
if unbounded:
    from datasets.nerf_360_v2 import SubjectLoader

    # training parameters
    weight_decay = 0.0
    # scene parameters
    aabb = torch.tensor([-1.0] * 3 + [1.0] * 3, device=device)
    near_plane = 0.2  # TODO: Try 0.02
    far_plane = 1e3
    # dataset parameters
    train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
    test_dataset_kwargs = {"factor": 4}
    # model parameters: one proposal network per listed max resolution
    proposal_resolutions = (128, 256)
    # render parameters
    num_samples = 48
    num_samples_per_prop = [256, 96]
    sampling_type = "lindisp"
    opaque_bkgd = True
else:
    from datasets.nerf_synthetic import SubjectLoader

    # training parameters
    if args.scene in ("materials", "ficus", "drums"):
        weight_decay = 1e-5
    else:
        weight_decay = 1e-6
    # scene parameters
    aabb = torch.tensor([-1.5] * 3 + [1.5] * 3, device=device)
    near_plane = 2.0
    far_plane = 6.0
    # dataset parameters
    train_dataset_kwargs = {}
    test_dataset_kwargs = {}
    # model parameters: one proposal network per listed max resolution
    proposal_resolutions = (128,)
    # render parameters
    num_samples = 64
    num_samples_per_prop = [128]
    sampling_type = "uniform"
    opaque_bkgd = False

proposal_networks = [
    NGPDensityField(
        aabb=aabb,
        unbounded=unbounded,
        n_levels=5,
        max_resolution=resolution,
    ).to(device)
    for resolution in proposal_resolutions
]

# Keyword arguments passed unchanged to every render_image_proposal call,
# for both the training and the test renders.
render_options = {
    "scene_aabb": None,
    "num_samples": num_samples,
    "num_samples_per_prop": num_samples_per_prop,
    "near_plane": near_plane,
    "far_plane": far_plane,
    "sampling_type": sampling_type,
    "opaque_bkgd": opaque_bkgd,
}

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
train_dataset = SubjectLoader(
    subject_id=args.scene,
    root_fp=args.data_root,
    split=args.train_split,
    num_rays=init_batch_size,
    device=device,
    **train_dataset_kwargs,
)
test_dataset = SubjectLoader(
    subject_id=args.scene,
    root_fp=args.data_root,
    split="test",
    num_rays=None,
    device=device,
    **test_dataset_kwargs,
)

# ---------------------------------------------------------------------------
# Model, optimizer and learning-rate schedule
# ---------------------------------------------------------------------------
# setup the radiance field we want to train.
grad_scaler = torch.cuda.amp.GradScaler(2**10)
radiance_field = NGPRadianceField(aabb=aabb, unbounded=unbounded).to(device)

# A single Adam instance updates the radiance field and every proposal net.
optimizer = torch.optim.Adam(
    itertools.chain.from_iterable(
        net.parameters() for net in (radiance_field, *proposal_networks)
    ),
    lr=1e-2,
    eps=1e-15,
    weight_decay=weight_decay,
)

# Linear ramp over the first 100 iterations, chained with step decays at
# 1/2, 3/4 and 9/10 of the training budget.
warmup_schedule = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.01, total_iters=100
)
decay_schedule = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[
        max_steps // 2,
        max_steps * 3 // 4,
        max_steps * 9 // 10,
    ],
    gamma=0.33,
)
scheduler = torch.optim.lr_scheduler.ChainedScheduler(
    [warmup_schedule, decay_schedule]
)

proposal_requires_grad_fn = get_proposal_requires_grad_fn()
proposal_annealing_fn = get_proposal_annealing_fn()

lpips_net = LPIPS(net="vgg").to(device)


def lpips_norm_fn(x):
    """Prepend a batch axis, permute axes to (0, 3, 1, 2) and map x -> 2x - 1.

    NOTE(review): presumably x is an (H, W, 3) image in [0, 1] -- confirm.
    """
    return x[None, ...].permute(0, 3, 1, 2) * 2 - 1


def lpips_fn(x, y):
    """Return the mean LPIPS network output for the image pair (x, y)."""
    return lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()


# ---------------------------------------------------------------------------
# Training (evaluation runs once, on the last step)
# ---------------------------------------------------------------------------
tic = time.time()
for step in range(max_steps + 1):
    for net in (radiance_field, *proposal_networks):
        net.train()

    # Draw one random entry of the training set for this step.
    i = torch.randint(0, len(train_dataset), (1,)).item()
    data = train_dataset[i]

    render_bkgd = data["color_bkgd"]
    rays = data["rays"]
    pixels = data["pixels"]

    # render
    rendered = render_image_proposal(
        radiance_field,
        proposal_networks,
        rays,
        render_bkgd=render_bkgd,
        # train options
        proposal_requires_grad=proposal_requires_grad_fn(step),
        proposal_annealing=proposal_annealing_fn(step),
        # rendering options
        **render_options,
    )
    rgb, acc, depth, weights_per_level, s_vals_per_level = rendered

    # compute loss: photometric term plus the proposal-network term
    rgb_loss = F.smooth_l1_loss(rgb, pixels)
    loss_prop = compute_prop_loss(s_vals_per_level, weights_per_level)
    loss = rgb_loss + loss_prop

    optimizer.zero_grad()
    # do not unscale it because we are using Adam.
    grad_scaler.scale(loss).backward()
    optimizer.step()
    scheduler.step()

    if step % 5000 == 0:
        elapsed_time = time.time() - tic
        # The logged "loss" is the plain MSE, not the optimised objective.
        loss = F.mse_loss(rgb, pixels)
        psnr = -10.0 * torch.log(loss) / np.log(10.0)
        print(
            f"elapsed_time={elapsed_time:.2f}s | step={step} | "
            f"loss={loss:.5f} | psnr={psnr:.2f} | "
            f"num_rays={len(pixels):d} | "
            f"max_depth={depth.max():.3f} | "
        )

    # Only a non-zero multiple of max_steps (i.e. the last step) evaluates.
    if step == 0 or step % max_steps != 0:
        continue

    # evaluation
    for net in (radiance_field, *proposal_networks):
        net.eval()

    psnrs = []
    lpips_scores = []
    with torch.no_grad():
        for i in tqdm.tqdm(range(len(test_dataset))):
            data = test_dataset[i]
            render_bkgd = data["color_bkgd"]
            rays = data["rays"]
            pixels = data["pixels"]

            # rendering
            rgb, acc, depth, _, _ = render_image_proposal(
                radiance_field,
                proposal_networks,
                rays,
                render_bkgd=render_bkgd,
                proposal_annealing=proposal_annealing_fn(step),
                # test options
                test_chunk_size=args.test_chunk_size,
                # rendering options
                **render_options,
            )
            mse = F.mse_loss(rgb, pixels)
            psnr = -10.0 * torch.log(mse) / np.log(10.0)
            psnrs.append(psnr.item())
            lpips_scores.append(lpips_fn(rgb, pixels).item())
            # Uncomment to dump the first test render and its error map:
            # if i == 0:
            #     imageio.imwrite(
            #         "rgb_test.png",
            #         (rgb.cpu().numpy() * 255).astype(np.uint8),
            #     )
            #     imageio.imwrite(
            #         "rgb_error.png",
            #         (
            #             (rgb - pixels).norm(dim=-1).cpu().numpy() * 255
            #         ).astype(np.uint8),
            #     )
    psnr_avg = sum(psnrs) / len(psnrs)
    lpips_avg = sum(lpips_scores) / len(lpips_scores)
    print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")