Commit 9f90842b authored by Ruilong Li
Browse files

cleanup training code

parent e6647a00
...@@ -24,34 +24,11 @@ from examples.utils import ( ...@@ -24,34 +24,11 @@ from examples.utils import (
) )
from nerfacc.estimators.occ_grid import OccGridEstimator from nerfacc.estimators.occ_grid import OccGridEstimator
parser = argparse.ArgumentParser() def run(args):
parser.add_argument( device = "cuda:0"
"--data_root", set_random_seed(42)
type=str,
# default=str(pathlib.Path.cwd() / "data/360_v2"),
default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="train",
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
help="which scene to use",
)
args = parser.parse_args()
device = "cuda:0" if args.scene in MIPNERF360_UNBOUNDED_SCENES:
set_random_seed(42)
if args.scene in MIPNERF360_UNBOUNDED_SCENES:
from datasets.nerf_360_v2 import SubjectLoader from datasets.nerf_360_v2 import SubjectLoader
# training parameters # training parameters
...@@ -74,7 +51,7 @@ if args.scene in MIPNERF360_UNBOUNDED_SCENES: ...@@ -74,7 +51,7 @@ if args.scene in MIPNERF360_UNBOUNDED_SCENES:
alpha_thre = 1e-2 alpha_thre = 1e-2
cone_angle = 0.004 cone_angle = 0.004
else: else:
from datasets.nerf_synthetic import SubjectLoader from datasets.nerf_synthetic import SubjectLoader
# training parameters # training parameters
...@@ -99,35 +76,35 @@ else: ...@@ -99,35 +76,35 @@ else:
alpha_thre = 0.0 alpha_thre = 0.0
cone_angle = 0.0 cone_angle = 0.0
train_dataset = SubjectLoader( train_dataset = SubjectLoader(
subject_id=args.scene, subject_id=args.scene,
root_fp=args.data_root, root_fp=args.data_root,
split=args.train_split, split=args.train_split,
num_rays=init_batch_size, num_rays=init_batch_size,
device=device, device=device,
**train_dataset_kwargs, **train_dataset_kwargs,
) )
test_dataset = SubjectLoader( test_dataset = SubjectLoader(
subject_id=args.scene, subject_id=args.scene,
root_fp=args.data_root, root_fp=args.data_root,
split="test", split="test",
num_rays=None, num_rays=None,
device=device, device=device,
**test_dataset_kwargs, **test_dataset_kwargs,
) )
estimator = OccGridEstimator( estimator = OccGridEstimator(
roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl
).to(device) ).to(device)
# setup the radiance field we want to train. # setup the radiance field we want to train.
grad_scaler = torch.cuda.amp.GradScaler(2**10) grad_scaler = torch.cuda.amp.GradScaler(2**10)
radiance_field = NGPRadianceField(aabb=estimator.aabbs[-1]).to(device) radiance_field = NGPRadianceField(aabb=estimator.aabbs[-1]).to(device)
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
radiance_field.parameters(), lr=1e-2, eps=1e-15, weight_decay=weight_decay radiance_field.parameters(), lr=1e-2, eps=1e-15, weight_decay=weight_decay
) )
scheduler = torch.optim.lr_scheduler.ChainedScheduler( scheduler = torch.optim.lr_scheduler.ChainedScheduler(
[ [
torch.optim.lr_scheduler.LinearLR( torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=0.01, total_iters=100 optimizer, start_factor=0.01, total_iters=100
...@@ -142,14 +119,14 @@ scheduler = torch.optim.lr_scheduler.ChainedScheduler( ...@@ -142,14 +119,14 @@ scheduler = torch.optim.lr_scheduler.ChainedScheduler(
gamma=0.33, gamma=0.33,
), ),
] ]
) )
lpips_net = LPIPS(net="vgg").to(device) lpips_net = LPIPS(net="vgg").to(device)
lpips_norm_fn = lambda x: x[None, ...].permute(0, 3, 1, 2) * 2 - 1 lpips_norm_fn = lambda x: x[None, ...].permute(0, 3, 1, 2) * 2 - 1
lpips_fn = lambda x, y: lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean() lpips_fn = lambda x, y: lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()
# training # training
tic = time.time() tic = time.time()
for step in range(max_steps + 1): for step in range(max_steps + 1):
radiance_field.train() radiance_field.train()
estimator.train() estimator.train()
...@@ -229,9 +206,20 @@ for step in range(max_steps + 1): ...@@ -229,9 +206,20 @@ for step in range(max_steps + 1):
pixels = data["pixels"] pixels = data["pixels"]
# rendering # rendering
rgb, acc, depth, _ = render_image_with_occgrid_test( # rgb, acc, depth, _ = render_image_with_occgrid_test(
1024, # 1024,
# scene # # scene
# radiance_field,
# estimator,
# rays,
# # rendering options
# near_plane=near_plane,
# render_step_size=render_step_size,
# render_bkgd=render_bkgd,
# cone_angle=cone_angle,
# alpha_thre=alpha_thre,
# )
rgb, acc, depth, _ = render_image_with_occgrid(
radiance_field, radiance_field,
estimator, estimator,
rays, rays,
...@@ -260,3 +248,30 @@ for step in range(max_steps + 1): ...@@ -260,3 +248,30 @@ for step in range(max_steps + 1):
psnr_avg = sum(psnrs) / len(psnrs) psnr_avg = sum(psnrs) / len(psnrs)
lpips_avg = sum(lpips) / len(lpips) lpips_avg = sum(lpips) / len(lpips)
print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}") print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
type=str,
# default=str(pathlib.Path.cwd() / "data/360_v2"),
default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="train",
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
help="which scene to use",
)
args = parser.parse_args()
run(args)
\ No newline at end of file
...@@ -79,9 +79,21 @@ def render_image_with_occgrid( ...@@ -79,9 +79,21 @@ def render_image_with_occgrid(
else: else:
num_rays, _ = rays_shape num_rays, _ = rays_shape
results = []
chunk = (
torch.iinfo(torch.int32).max
if radiance_field.training
else test_chunk_size
)
for i in range(0, num_rays, chunk):
chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
rays_o = chunk_rays.origins
rays_d = chunk_rays.viewdirs
def sigma_fn(t_starts, t_ends, ray_indices): def sigma_fn(t_starts, t_ends, ray_indices):
t_origins = chunk_rays.origins[ray_indices] t_origins = rays_o[ray_indices]
t_dirs = chunk_rays.viewdirs[ray_indices] t_dirs = rays_d[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0 positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
if timestamps is not None: if timestamps is not None:
# dnerf # dnerf
...@@ -96,8 +108,8 @@ def render_image_with_occgrid( ...@@ -96,8 +108,8 @@ def render_image_with_occgrid(
return sigmas.squeeze(-1) return sigmas.squeeze(-1)
def rgb_sigma_fn(t_starts, t_ends, ray_indices): def rgb_sigma_fn(t_starts, t_ends, ray_indices):
t_origins = chunk_rays.origins[ray_indices] t_origins = rays_o[ray_indices]
t_dirs = chunk_rays.viewdirs[ray_indices] t_dirs = rays_d[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0 positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
if timestamps is not None: if timestamps is not None:
# dnerf # dnerf
...@@ -111,17 +123,9 @@ def render_image_with_occgrid( ...@@ -111,17 +123,9 @@ def render_image_with_occgrid(
rgbs, sigmas = radiance_field(positions, t_dirs) rgbs, sigmas = radiance_field(positions, t_dirs)
return rgbs, sigmas.squeeze(-1) return rgbs, sigmas.squeeze(-1)
results = []
chunk = (
torch.iinfo(torch.int32).max
if radiance_field.training
else test_chunk_size
)
for i in range(0, num_rays, chunk):
chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
ray_indices, t_starts, t_ends = estimator.sampling( ray_indices, t_starts, t_ends = estimator.sampling(
chunk_rays.origins, rays_o,
chunk_rays.viewdirs, rays_d,
sigma_fn=sigma_fn, sigma_fn=sigma_fn,
near_plane=near_plane, near_plane=near_plane,
far_plane=far_plane, far_plane=far_plane,
...@@ -134,7 +138,7 @@ def render_image_with_occgrid( ...@@ -134,7 +138,7 @@ def render_image_with_occgrid(
t_starts, t_starts,
t_ends, t_ends,
ray_indices, ray_indices,
n_rays=chunk_rays.origins.shape[0], n_rays=rays_o.shape[0],
rgb_sigma_fn=rgb_sigma_fn, rgb_sigma_fn=rgb_sigma_fn,
render_bkgd=render_bkgd, render_bkgd=render_bkgd,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment