Commit 9f90842b authored by Ruilong Li's avatar Ruilong Li
Browse files

cleanup training code

parent e6647a00
...@@ -24,239 +24,254 @@ from examples.utils import ( ...@@ -24,239 +24,254 @@ from examples.utils import (
) )
from nerfacc.estimators.occ_grid import OccGridEstimator from nerfacc.estimators.occ_grid import OccGridEstimator
parser = argparse.ArgumentParser() def run(args):
parser.add_argument( device = "cuda:0"
"--data_root", set_random_seed(42)
type=str,
# default=str(pathlib.Path.cwd() / "data/360_v2"),
default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="train",
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
help="which scene to use",
)
args = parser.parse_args()
device = "cuda:0" if args.scene in MIPNERF360_UNBOUNDED_SCENES:
set_random_seed(42) from datasets.nerf_360_v2 import SubjectLoader
if args.scene in MIPNERF360_UNBOUNDED_SCENES: # training parameters
from datasets.nerf_360_v2 import SubjectLoader max_steps = 20000
init_batch_size = 1024
target_sample_batch_size = 1 << 18
weight_decay = 0.0
# scene parameters
aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=device)
near_plane = 0.2
far_plane = 1.0e10
# dataset parameters
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4}
# model parameters
grid_resolution = 128
grid_nlvl = 4
# render parameters
render_step_size = 1e-3
alpha_thre = 1e-2
cone_angle = 0.004
# training parameters else:
max_steps = 20000 from datasets.nerf_synthetic import SubjectLoader
init_batch_size = 1024
target_sample_batch_size = 1 << 18
weight_decay = 0.0
# scene parameters
aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=device)
near_plane = 0.2
far_plane = 1.0e10
# dataset parameters
train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
test_dataset_kwargs = {"factor": 4}
# model parameters
grid_resolution = 128
grid_nlvl = 4
# render parameters
render_step_size = 1e-3
alpha_thre = 1e-2
cone_angle = 0.004
else: # training parameters
from datasets.nerf_synthetic import SubjectLoader max_steps = 20000
init_batch_size = 1024
target_sample_batch_size = 1 << 18
weight_decay = (
1e-5 if args.scene in ["materials", "ficus", "drums"] else 1e-6
)
# scene parameters
aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], device=device)
near_plane = 0.0
far_plane = 1.0e10
# dataset parameters
train_dataset_kwargs = {}
test_dataset_kwargs = {}
# model parameters
grid_resolution = 128
grid_nlvl = 1
# render parameters
render_step_size = 5e-3
alpha_thre = 0.0
cone_angle = 0.0
# training parameters train_dataset = SubjectLoader(
max_steps = 20000 subject_id=args.scene,
init_batch_size = 1024 root_fp=args.data_root,
target_sample_batch_size = 1 << 18 split=args.train_split,
weight_decay = ( num_rays=init_batch_size,
1e-5 if args.scene in ["materials", "ficus", "drums"] else 1e-6 device=device,
**train_dataset_kwargs,
) )
# scene parameters
aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], device=device)
near_plane = 0.0
far_plane = 1.0e10
# dataset parameters
train_dataset_kwargs = {}
test_dataset_kwargs = {}
# model parameters
grid_resolution = 128
grid_nlvl = 1
# render parameters
render_step_size = 5e-3
alpha_thre = 0.0
cone_angle = 0.0
train_dataset = SubjectLoader( test_dataset = SubjectLoader(
subject_id=args.scene, subject_id=args.scene,
root_fp=args.data_root, root_fp=args.data_root,
split=args.train_split, split="test",
num_rays=init_batch_size, num_rays=None,
device=device, device=device,
**train_dataset_kwargs, **test_dataset_kwargs,
) )
test_dataset = SubjectLoader( estimator = OccGridEstimator(
subject_id=args.scene, roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl
root_fp=args.data_root, ).to(device)
split="test",
num_rays=None,
device=device,
**test_dataset_kwargs,
)
estimator = OccGridEstimator( # setup the radiance field we want to train.
roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl grad_scaler = torch.cuda.amp.GradScaler(2**10)
).to(device) radiance_field = NGPRadianceField(aabb=estimator.aabbs[-1]).to(device)
optimizer = torch.optim.Adam(
radiance_field.parameters(), lr=1e-2, eps=1e-15, weight_decay=weight_decay
)
scheduler = torch.optim.lr_scheduler.ChainedScheduler(
[
torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=0.01, total_iters=100
),
torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
max_steps // 2,
max_steps * 3 // 4,
max_steps * 9 // 10,
],
gamma=0.33,
),
]
)
lpips_net = LPIPS(net="vgg").to(device)
lpips_norm_fn = lambda x: x[None, ...].permute(0, 3, 1, 2) * 2 - 1
lpips_fn = lambda x, y: lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()
# setup the radiance field we want to train. # training
grad_scaler = torch.cuda.amp.GradScaler(2**10) tic = time.time()
radiance_field = NGPRadianceField(aabb=estimator.aabbs[-1]).to(device) for step in range(max_steps + 1):
optimizer = torch.optim.Adam( radiance_field.train()
radiance_field.parameters(), lr=1e-2, eps=1e-15, weight_decay=weight_decay estimator.train()
)
scheduler = torch.optim.lr_scheduler.ChainedScheduler(
[
torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=0.01, total_iters=100
),
torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
max_steps // 2,
max_steps * 3 // 4,
max_steps * 9 // 10,
],
gamma=0.33,
),
]
)
lpips_net = LPIPS(net="vgg").to(device)
lpips_norm_fn = lambda x: x[None, ...].permute(0, 3, 1, 2) * 2 - 1
lpips_fn = lambda x, y: lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()
# training i = torch.randint(0, len(train_dataset), (1,)).item()
tic = time.time() data = train_dataset[i]
for step in range(max_steps + 1):
radiance_field.train()
estimator.train()
i = torch.randint(0, len(train_dataset), (1,)).item() render_bkgd = data["color_bkgd"]
data = train_dataset[i] rays = data["rays"]
pixels = data["pixels"]
render_bkgd = data["color_bkgd"] def occ_eval_fn(x):
rays = data["rays"] density = radiance_field.query_density(x)
pixels = data["pixels"] return density * render_step_size
def occ_eval_fn(x): # update occupancy grid
density = radiance_field.query_density(x) estimator.update_every_n_steps(
return density * render_step_size step=step,
occ_eval_fn=occ_eval_fn,
occ_thre=1e-2,
)
# update occupancy grid # render
estimator.update_every_n_steps( rgb, acc, depth, n_rendering_samples = render_image_with_occgrid(
step=step, radiance_field,
occ_eval_fn=occ_eval_fn, estimator,
occ_thre=1e-2, rays,
) # rendering options
near_plane=near_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=cone_angle,
alpha_thre=alpha_thre,
)
if n_rendering_samples == 0:
continue
# render if target_sample_batch_size > 0:
rgb, acc, depth, n_rendering_samples = render_image_with_occgrid( # dynamic batch size for rays to keep sample batch size constant.
radiance_field, num_rays = len(pixels)
estimator, num_rays = int(
rays, num_rays * (target_sample_batch_size / float(n_rendering_samples))
# rendering options )
near_plane=near_plane, train_dataset.update_num_rays(num_rays)
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=cone_angle,
alpha_thre=alpha_thre,
)
if n_rendering_samples == 0:
continue
if target_sample_batch_size > 0: # compute loss
# dynamic batch size for rays to keep sample batch size constant. loss = F.smooth_l1_loss(rgb, pixels)
num_rays = len(pixels)
num_rays = int(
num_rays * (target_sample_batch_size / float(n_rendering_samples))
)
train_dataset.update_num_rays(num_rays)
# compute loss optimizer.zero_grad()
loss = F.smooth_l1_loss(rgb, pixels) # do not unscale it because we are using Adam.
grad_scaler.scale(loss).backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad() if step % 10000 == 0:
# do not unscale it because we are using Adam. elapsed_time = time.time() - tic
grad_scaler.scale(loss).backward() loss = F.mse_loss(rgb, pixels)
optimizer.step() psnr = -10.0 * torch.log(loss) / np.log(10.0)
scheduler.step() print(
f"elapsed_time={elapsed_time:.2f}s | step={step} | "
f"loss={loss:.5f} | psnr={psnr:.2f} | "
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} | "
f"max_depth={depth.max():.3f} | "
)
if step % 10000 == 0: if step > 0 and step % max_steps == 0:
elapsed_time = time.time() - tic # evaluation
loss = F.mse_loss(rgb, pixels) radiance_field.eval()
psnr = -10.0 * torch.log(loss) / np.log(10.0) estimator.eval()
print(
f"elapsed_time={elapsed_time:.2f}s | step={step} | " psnrs = []
f"loss={loss:.5f} | psnr={psnr:.2f} | " lpips = []
f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} | " with torch.no_grad():
f"max_depth={depth.max():.3f} | " for i in tqdm.tqdm(range(len(test_dataset))):
) data = test_dataset[i]
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
if step > 0 and step % max_steps == 0: # rendering
# evaluation # rgb, acc, depth, _ = render_image_with_occgrid_test(
radiance_field.eval() # 1024,
estimator.eval() # # scene
# radiance_field,
# estimator,
# rays,
# # rendering options
# near_plane=near_plane,
# render_step_size=render_step_size,
# render_bkgd=render_bkgd,
# cone_angle=cone_angle,
# alpha_thre=alpha_thre,
# )
rgb, acc, depth, _ = render_image_with_occgrid(
radiance_field,
estimator,
rays,
# rendering options
near_plane=near_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=cone_angle,
alpha_thre=alpha_thre,
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
lpips.append(lpips_fn(rgb, pixels).item())
# if i == 0:
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_error.png",
# (
# (rgb - pixels).norm(dim=-1).cpu().numpy() * 255
# ).astype(np.uint8),
# )
psnr_avg = sum(psnrs) / len(psnrs)
lpips_avg = sum(lpips) / len(lpips)
print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")
psnrs = [] if __name__ == "__main__":
lpips = [] parser = argparse.ArgumentParser()
with torch.no_grad(): parser.add_argument(
for i in tqdm.tqdm(range(len(test_dataset))): "--data_root",
data = test_dataset[i] type=str,
render_bkgd = data["color_bkgd"] # default=str(pathlib.Path.cwd() / "data/360_v2"),
rays = data["rays"] default=str(pathlib.Path.cwd() / "data/nerf_synthetic"),
pixels = data["pixels"] help="the root dir of the dataset",
)
parser.add_argument(
"--train_split",
type=str,
default="train",
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
help="which scene to use",
)
args = parser.parse_args()
# rendering run(args)
rgb, acc, depth, _ = render_image_with_occgrid_test( \ No newline at end of file
1024,
# scene
radiance_field,
estimator,
rays,
# rendering options
near_plane=near_plane,
render_step_size=render_step_size,
render_bkgd=render_bkgd,
cone_angle=cone_angle,
alpha_thre=alpha_thre,
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
lpips.append(lpips_fn(rgb, pixels).item())
# if i == 0:
# imageio.imwrite(
# "rgb_test.png",
# (rgb.cpu().numpy() * 255).astype(np.uint8),
# )
# imageio.imwrite(
# "rgb_error.png",
# (
# (rgb - pixels).norm(dim=-1).cpu().numpy() * 255
# ).astype(np.uint8),
# )
psnr_avg = sum(psnrs) / len(psnrs)
lpips_avg = sum(lpips) / len(lpips)
print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")
...@@ -79,38 +79,6 @@ def render_image_with_occgrid( ...@@ -79,38 +79,6 @@ def render_image_with_occgrid(
else: else:
num_rays, _ = rays_shape num_rays, _ = rays_shape
def sigma_fn(t_starts, t_ends, ray_indices):
t_origins = chunk_rays.origins[ray_indices]
t_dirs = chunk_rays.viewdirs[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
if timestamps is not None:
# dnerf
t = (
timestamps[ray_indices]
if radiance_field.training
else timestamps.expand_as(positions[:, :1])
)
sigmas = radiance_field.query_density(positions, t)
else:
sigmas = radiance_field.query_density(positions)
return sigmas.squeeze(-1)
def rgb_sigma_fn(t_starts, t_ends, ray_indices):
t_origins = chunk_rays.origins[ray_indices]
t_dirs = chunk_rays.viewdirs[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
if timestamps is not None:
# dnerf
t = (
timestamps[ray_indices]
if radiance_field.training
else timestamps.expand_as(positions[:, :1])
)
rgbs, sigmas = radiance_field(positions, t, t_dirs)
else:
rgbs, sigmas = radiance_field(positions, t_dirs)
return rgbs, sigmas.squeeze(-1)
results = [] results = []
chunk = ( chunk = (
torch.iinfo(torch.int32).max torch.iinfo(torch.int32).max
...@@ -119,9 +87,45 @@ def render_image_with_occgrid( ...@@ -119,9 +87,45 @@ def render_image_with_occgrid(
) )
for i in range(0, num_rays, chunk): for i in range(0, num_rays, chunk):
chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays) chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
rays_o = chunk_rays.origins
rays_d = chunk_rays.viewdirs
def sigma_fn(t_starts, t_ends, ray_indices):
t_origins = rays_o[ray_indices]
t_dirs = rays_d[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
if timestamps is not None:
# dnerf
t = (
timestamps[ray_indices]
if radiance_field.training
else timestamps.expand_as(positions[:, :1])
)
sigmas = radiance_field.query_density(positions, t)
else:
sigmas = radiance_field.query_density(positions)
return sigmas.squeeze(-1)
def rgb_sigma_fn(t_starts, t_ends, ray_indices):
t_origins = rays_o[ray_indices]
t_dirs = rays_d[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
if timestamps is not None:
# dnerf
t = (
timestamps[ray_indices]
if radiance_field.training
else timestamps.expand_as(positions[:, :1])
)
rgbs, sigmas = radiance_field(positions, t, t_dirs)
else:
rgbs, sigmas = radiance_field(positions, t_dirs)
return rgbs, sigmas.squeeze(-1)
ray_indices, t_starts, t_ends = estimator.sampling( ray_indices, t_starts, t_ends = estimator.sampling(
chunk_rays.origins, rays_o,
chunk_rays.viewdirs, rays_d,
sigma_fn=sigma_fn, sigma_fn=sigma_fn,
near_plane=near_plane, near_plane=near_plane,
far_plane=far_plane, far_plane=far_plane,
...@@ -134,7 +138,7 @@ def render_image_with_occgrid( ...@@ -134,7 +138,7 @@ def render_image_with_occgrid(
t_starts, t_starts,
t_ends, t_ends,
ray_indices, ray_indices,
n_rays=chunk_rays.origins.shape[0], n_rays=rays_o.shape[0],
rgb_sigma_fn=rgb_sigma_fn, rgb_sigma_fn=rgb_sigma_fn,
render_bkgd=render_bkgd, render_bkgd=render_bkgd,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment