Commit f00ef667 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

NeRF training: avoid caching unused visualization data.

Summary: If we are not visualizing the training with visdom, then there are a couple of outputs of the coarse rendering step which are not small and are returned by the renderer but never used. We don't need to bother transferring them to the CPU.

Reviewed By: nikhilaravi

Differential Revision: D28939958

fbshipit-source-id: 7e0d6681d6524f7fb57b6b20164580006120de80
parent 7204a4ca
......@@ -64,6 +64,7 @@ class RadianceFieldRenderer(torch.nn.Module):
n_layers_xyz: int = 8,
append_xyz: Tuple[int] = (5,),
density_noise_std: float = 0.0,
visualization: bool = False,
):
"""
Args:
......@@ -102,6 +103,7 @@ class RadianceFieldRenderer(torch.nn.Module):
density_noise_std: The standard deviation of the random normal noise
added to the output of the occupancy MLP.
Active only when `self.training==True`.
visualization: whether to store extra output for visualization.
"""
super().__init__()
......@@ -159,6 +161,7 @@ class RadianceFieldRenderer(torch.nn.Module):
self._density_noise_std = density_noise_std
self._chunk_size_test = chunk_size_test
self._image_size = image_size
self.visualization = visualization
def precache_rays(
self,
......@@ -248,16 +251,15 @@ class RadianceFieldRenderer(torch.nn.Module):
else:
raise ValueError(f"No such rendering pass {renderer_pass}")
return {
"rgb_fine": rgb_fine,
"rgb_coarse": rgb_coarse,
"rgb_gt": rgb_gt,
out = {"rgb_fine": rgb_fine, "rgb_coarse": rgb_coarse, "rgb_gt": rgb_gt}
if self.visualization:
# Store the coarse rays/weights only for visualization purposes.
"coarse_ray_bundle": type(coarse_ray_bundle)(
out["coarse_ray_bundle"] = type(coarse_ray_bundle)(
*[v.detach().cpu() for k, v in coarse_ray_bundle._asdict().items()]
),
"coarse_weights": coarse_weights.detach().cpu(),
}
)
out["coarse_weights"] = coarse_weights.detach().cpu()
return out
def forward(
self,
......
......@@ -52,6 +52,7 @@ def main(cfg: DictConfig):
n_hidden_neurons_dir=cfg.implicit_function.n_hidden_neurons_dir,
n_layers_xyz=cfg.implicit_function.n_layers_xyz,
density_noise_std=cfg.implicit_function.density_noise_std,
visualization=cfg.visualization.visdom,
)
# Move the model to the relevant device.
......@@ -195,17 +196,18 @@ def main(cfg: DictConfig):
stats.print(stat_set="train")
# Update the visualization cache.
visuals_cache.append(
{
"camera": camera.cpu(),
"camera_idx": camera_idx,
"image": image.cpu().detach(),
"rgb_fine": nerf_out["rgb_fine"].cpu().detach(),
"rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(),
"rgb_gt": nerf_out["rgb_gt"].cpu().detach(),
"coarse_ray_bundle": nerf_out["coarse_ray_bundle"],
}
)
if viz is not None:
visuals_cache.append(
{
"camera": camera.cpu(),
"camera_idx": camera_idx,
"image": image.cpu().detach(),
"rgb_fine": nerf_out["rgb_fine"].cpu().detach(),
"rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(),
"rgb_gt": nerf_out["rgb_gt"].cpu().detach(),
"coarse_ray_bundle": nerf_out["coarse_ray_bundle"],
}
)
# Adjust the learning rate.
lr_scheduler.step()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment